Extract interaction constraint from split evaluator. (#5034)

*  Extract interaction constraints from split evaluator.

The main motivation is model IO, where num_feature and interaction_constraints were duplicated inside the split evaluator. Also, the interaction constraint is by itself a feature selector, acting like the column sampler, so it's inefficient to bury it deep in the evaluator chain. Lastly, removing one more copied parameter is a win.

*  Enable interaction constraints for the approx tree method.

Now that the implementation is split out from the evaluator class, it is also enabled for the approx tree method.
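As a quick illustration of the user-visible effect (a sketch of mine, not part of this diff; the data path is a placeholder), interaction constraints can now be combined with the approx updater through the existing C API:

#include <xgboost/c_api.h>

int main() {
  DMatrixHandle dtrain;
  XGDMatrixCreateFromFile("train.libsvm", 0, &dtrain);  // placeholder file
  BoosterHandle booster;
  XGBoosterCreate(&dtrain, 1, &booster);
  // Two constraint groups: features {0, 1} and {2, 3, 4} may interact.
  XGBoosterSetParam(booster, "interaction_constraints", "[[0, 1], [2, 3, 4]]");
  XGBoosterSetParam(booster, "tree_method", "approx");  // newly supported here
  for (int iter = 0; iter < 10; ++iter) {
    XGBoosterUpdateOneIter(booster, iter, dtrain);
  }
  XGBoosterFree(booster);
  XGDMatrixFree(dtrain);
  return 0;
}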

*  Remove obsolete code in colmaker.

These code paths were never documented nor actually used in the real world, and there isn't a single test covering them.

*  Unify the types used for row and column indices.

As input datasets march toward billions of rows, incorrect use of int is subject to overflow, and signed integer overflow is undefined behaviour. This PR starts the process of unifying the index types on unsigned integers. There are compiler optimizations that exploit this undefined behaviour, but after some testing I don't see them benefiting XGBoost.
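A minimal sketch of the overflow hazard being addressed (mine, not from the diff; the constants are illustrative):

#include <cstddef>
#include <cstdint>

int main() {
  std::size_t n = 3000000000ULL;  // ~3 billion entries: fits in 64-bit unsigned
  // With 32-bit signed indices this count is already unrepresentable, and any
  // arithmetic that overflows a signed type is undefined behaviour, whereas
  // unsigned arithmetic is defined to wrap modulo 2^64.
  std::size_t next = n + 1;        // well-defined
  // int32_t bad = static_cast<int32_t>(n);  // out of range: not portable
  return next > n ? 0 : 1;
}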
Jiaming Yuan, 2019-11-14 20:11:41 +08:00, committed by GitHub
parent 886bf93ba4, commit 97abcc7ee2
45 changed files with 688 additions and 652 deletions

View File

@@ -53,6 +53,7 @@
 #include "../src/tree/updater_sync.cc"
 #include "../src/tree/updater_histmaker.cc"
 #include "../src/tree/updater_skmaker.cc"
+#include "../src/tree/constraints.cc"
 // linear
 #include "../src/linear/linear_updater.cc"

View File

@@ -142,6 +142,8 @@ Parameters for Tree Booster
 - ``grow_histmaker``: distributed tree construction with row-based data splitting based on global proposal of histogram counting.
 - ``grow_local_histmaker``: based on local histogram counting.
 - ``grow_skmaker``: uses the approximate sketching algorithm.
+- ``grow_quantile_histmaker``: Grow tree using quantized histogram.
+- ``grow_gpu_hist``: Grow tree with GPU.
 - ``sync``: synchronizes trees in all distributed nodes.
 - ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
 - ``prune``: prunes the splits where loss < min_split_loss (or gamma).

View File

@@ -172,9 +172,9 @@ parameter:
                          early_stopping_rounds = 10)

 **Choice of tree construction algorithm**. To use feature interaction constraints, be sure
-to set the ``tree_method`` parameter to one of the following: ``exact``, ``hist`` or
-``gpu_hist``. Support for ``gpu_hist`` is added after (excluding) version 0.90.
+to set the ``tree_method`` parameter to one of the following: ``exact``, ``hist``,
+``approx`` or ``gpu_hist``. Support for ``gpu_hist`` and ``approx`` is added only in
+1.0.0.

 **************
 Advanced topic

View File

@@ -100,17 +100,30 @@
 /*! \brief namespace of xgboost*/
 namespace xgboost {
-/*!
- * \brief unsigned integer type used in boost,
- *  used for feature index and row index.
- */
+/*! \brief unsigned integer type used for feature index. */
 using bst_uint = uint32_t;  // NOLINT
+/*! \brief integer type. */
 using bst_int = int32_t;    // NOLINT
-/*! \brief long integers */
-typedef uint64_t bst_ulong;  // NOLINT(*)
+/*! \brief unsigned long integers */
+using bst_ulong = uint64_t;  // NOLINT
 /*! \brief float type, used for storing statistics */
 using bst_float = float;  // NOLINT
+/*! \brief Type for data column (feature) index. */
+using bst_feature_t = uint32_t;  // NOLINT
+/*! \brief Type for data row index.
+ *
+ * Be careful: `std::size_t' is implementation-defined, meaning that the binary
+ * representation of DMatrix might not be portable across platforms.  The booster
+ * model should be portable, as parameters are floating points.
+ */
+using bst_row_t = std::size_t;   // NOLINT
+/*! \brief Type for tree node index. */
+using bst_node_t = int32_t;      // NOLINT
+/*! \brief Type for ranking group index. */
+using bst_group_t = uint32_t;    // NOLINT

 namespace detail {
 /*! \brief Implementation of gradient statistics pair. Template specialisation
  * may be used to overload different gradients types e.g. low precision, high

View File

@@ -57,7 +57,7 @@ class MetaInfo {
   * \brief the index of begin and end of a group
   *  needed when the learning task is ranking.
   */
-  std::vector<bst_uint> group_ptr_;
+  std::vector<bst_group_t> group_ptr_;
   /*! \brief weights of each instance, optional */
   HostDeviceVector<bst_float> weights_;
   /*!
@@ -136,7 +136,7 @@ class MetaInfo {
 /*! \brief Element from a sparse vector */
 struct Entry {
   /*! \brief feature index */
-  bst_uint index;
+  bst_feature_t index;
   /*! \brief feature value */
   bst_float fvalue;
   /*! \brief default constructor */
@@ -146,7 +146,7 @@ struct Entry {
    * \param index The feature or row index.
    * \param fvalue The feature value.
    */
-  Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
+  Entry(bst_feature_t index, bst_float fvalue) : index(index), fvalue(fvalue) {}
   /*! \brief reversely compare feature values */
   inline static bool CmpValue(const Entry& a, const Entry& b) {
     return a.fvalue < b.fvalue;
@@ -174,7 +174,7 @@ struct BatchParam {
 class SparsePage {
  public:
   // Offset for each row.
-  HostDeviceVector<size_t> offset;
+  HostDeviceVector<bst_row_t> offset;
   /*! \brief the data of the segments */
   HostDeviceVector<Entry> data;

View File

@@ -21,7 +21,7 @@
 namespace xgboost {
 class TreeUpdater;
 namespace gbm {
-class GBTreeModel;
+struct GBTreeModel;
 }  // namespace gbm
 }

View File

@@ -267,7 +267,9 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
   data::SimpleCSRSource& mat = *source;
   auto& offset_vec = mat.page_.offset.HostVector();
   auto& data_vec = mat.page_.data.HostVector();
-  common::ParallelGroupBuilder<Entry> builder(&offset_vec, &data_vec);
+  common::ParallelGroupBuilder<
+      Entry, std::remove_reference<decltype(offset_vec)>::type::value_type>
+      builder(&offset_vec, &data_vec);
   builder.InitBudget(0, nthread);
   size_t ncol = nindptr - 1;  // NOLINT(*)
   #pragma omp parallel for schedule(static)
@@ -362,19 +364,20 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
   API_END();
 }

-void PrefixSum(size_t *x, size_t N) {
-  size_t *suma;
+template <typename T>
+void PrefixSum(T *x, size_t N) {
+  std::vector<T> suma;
 #pragma omp parallel
   {
     const int ithread = omp_get_thread_num();
     const int nthreads = omp_get_num_threads();
 #pragma omp single
     {
-      suma = new size_t[nthreads+1];
+      suma.resize(nthreads+1);
       suma[0] = 0;
     }
-    size_t sum = 0;
-    size_t offset = 0;
+    T sum = 0;
+    T offset = 0;
 #pragma omp for schedule(static)
     for (omp_ulong i = 0; i < N; i++) {
       sum += x[i];
@@ -390,7 +393,6 @@ void PrefixSum(size_t *x, size_t N) {
       x[i] += offset;
     }
   }
-  delete[] suma;
 }

 XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data,  // NOLINT
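For reference, the templated PrefixSum above appears to compute an in-place inclusive scan in parallel (matching the thrust::inclusive_scan used on the GPU path). A minimal sanity sketch of that contract, under the assumption that the declaration matches the diff:

uint64_t x[4] = {1, 2, 3, 4};
PrefixSum(x, 4);  // expected: x becomes {1, 3, 6, 10}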

View File

@@ -16,6 +16,8 @@
 #include <vector>

+#include "xgboost/base.h"
+
 namespace xgboost {
 namespace common {
 /*!
@@ -23,7 +25,7 @@ namespace common {
  * \tparam ValueType type of entries in the sparse matrix
  * \tparam SizeType type of the index range holder
  */
-template<typename ValueType, typename SizeType = std::size_t>
+template<typename ValueType, typename SizeType = bst_ulong>
 struct ParallelGroupBuilder {
  public:
  // parallel group builder of data

View File

@@ -421,7 +421,7 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
 #pragma omp parallel for num_threads(nthread) schedule(static)
   for (bst_omp_uint idx = 0; idx < bst_omp_uint(nbins); ++idx) {
-    for (size_t tid = 0; tid < nthread; ++tid) {
+    for (int32_t tid = 0; tid < nthread; ++tid) {
       hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
       hit_count_tloc_[tid * nbins + idx] = 0;  // reset for next batch
     }

View File

@@ -157,9 +157,21 @@ void HostDeviceVector<T>::SetDevice(int device) const {}
 // explicit instantiations are required, as HostDeviceVector isn't header-only
 template class HostDeviceVector<bst_float>;
 template class HostDeviceVector<GradientPair>;
-template class HostDeviceVector<int>;
+template class HostDeviceVector<int32_t>;   // bst_node_t
 template class HostDeviceVector<Entry>;
-template class HostDeviceVector<size_t>;
+template class HostDeviceVector<uint64_t>;  // bst_row_t
+template class HostDeviceVector<uint32_t>;  // bst_feature_t
+
+#if defined(__APPLE__)
+/*
+ * On OSX:
+ *
+ * typedef unsigned int         uint32_t;
+ * typedef unsigned long long   uint64_t;
+ * typedef unsigned long       __darwin_size_t;
+ */
+template class HostDeviceVector<std::size_t>;
+#endif  // defined(__APPLE__)

 }  // namespace xgboost
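A standalone illustration of why the guarded instantiation is needed (this snippet is mine, not part of the diff): on macOS, std::size_t aliases unsigned long, a type distinct from both uint32_t and uint64_t even when the widths match, so HostDeviceVector<std::size_t> is a separate template instantiation there.

#include <cstdint>
#include <type_traits>

int main() {
  // True on common 64-bit Linux/glibc toolchains, false on macOS, where
  // uint64_t is 'unsigned long long' but size_t is 'unsigned long'.
  return std::is_same<std::size_t, std::uint64_t>::value ? 0 : 1;
}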

View File

@@ -351,8 +351,20 @@ void HostDeviceVector<T>::Resize(size_t new_size, T v) {
 // explicit instantiations are required, as HostDeviceVector isn't header-only
 template class HostDeviceVector<bst_float>;
 template class HostDeviceVector<GradientPair>;
-template class HostDeviceVector<int>;
+template class HostDeviceVector<int32_t>;   // bst_node_t
 template class HostDeviceVector<Entry>;
-template class HostDeviceVector<size_t>;
+template class HostDeviceVector<uint64_t>;  // bst_row_t
+template class HostDeviceVector<uint32_t>;  // bst_feature_t
+
+#if defined(__APPLE__)
+/*
+ * On OSX:
+ *
+ * typedef unsigned int         uint32_t;
+ * typedef unsigned long long   uint64_t;
+ * typedef unsigned long       __darwin_size_t;
+ */
+template class HostDeviceVector<std::size_t>;
+#endif  // defined(__APPLE__)

 }  // namespace xgboost

View File

@@ -85,20 +85,20 @@ GlobalRandomEngine& GlobalRandom(); // NOLINT(*)
  */
 class ColumnSampler {
-  std::shared_ptr<HostDeviceVector<int>> feature_set_tree_;
-  std::map<int, std::shared_ptr<HostDeviceVector<int>>> feature_set_level_;
+  std::shared_ptr<HostDeviceVector<bst_feature_t>> feature_set_tree_;
+  std::map<int, std::shared_ptr<HostDeviceVector<bst_feature_t>>> feature_set_level_;
   float colsample_bylevel_{1.0f};
   float colsample_bytree_{1.0f};
   float colsample_bynode_{1.0f};
   GlobalRandomEngine rng_;

-  std::shared_ptr<HostDeviceVector<int>> ColSample(
-      std::shared_ptr<HostDeviceVector<int>> p_features, float colsample) {
+  std::shared_ptr<HostDeviceVector<bst_feature_t>> ColSample(
+      std::shared_ptr<HostDeviceVector<bst_feature_t>> p_features, float colsample) {
     if (colsample == 1.0f) return p_features;
     const auto& features = p_features->HostVector();
     CHECK_GT(features.size(), 0);
     int n = std::max(1, static_cast<int>(colsample * features.size()));
-    auto p_new_features = std::make_shared<HostDeviceVector<int>>();
+    auto p_new_features = std::make_shared<HostDeviceVector<bst_feature_t>>();
     auto& new_features = *p_new_features;
     new_features.Resize(features.size());
     std::copy(features.begin(), features.end(),
@@ -147,7 +147,7 @@ class ColumnSampler {
     colsample_bynode_ = colsample_bynode;

     if (feature_set_tree_ == nullptr) {
-      feature_set_tree_ = std::make_shared<HostDeviceVector<int>>();
+      feature_set_tree_ = std::make_shared<HostDeviceVector<bst_feature_t>>();
     }
     Reset();
@@ -178,7 +178,7 @@ class ColumnSampler {
    * construction of each tree node, and must be called the same number of times in each
    * process and with the same parameters to return the same feature set across processes.
    */
-  std::shared_ptr<HostDeviceVector<int>> GetFeatureSet(int depth) {
+  std::shared_ptr<HostDeviceVector<bst_feature_t>> GetFeatureSet(int depth) {
     if (colsample_bylevel_ == 1.0f && colsample_bynode_ == 1.0f) {
       return feature_set_tree_;
     }

View File

@@ -229,7 +229,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
   std::unique_ptr<dmlc::Parser<uint32_t> > parser(
       dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
-  DMatrix* dmat;
+  DMatrix* dmat {nullptr};
   try {
     dmat = DMatrix::Create(parser.get(), cache_file, page_size);
@@ -253,9 +253,8 @@ DMatrix* DMatrix::Load(const std::string& uri,
          << "Choosing default parser in dmlc-core. "
          << "Consider providing a uri parameter like: filename?format=csv";
     }
+    LOG(FATAL) << "Encountered parser error:\n" << e.what();
   }
-  LOG(FATAL) << "Encountered parser error:\n" << e.what();
-}

   if (!silent) {
@@ -361,7 +360,7 @@ DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
 namespace xgboost {
 SparsePage SparsePage::GetTranspose(int num_columns) const {
   SparsePage transpose;
-  common::ParallelGroupBuilder<Entry> builder(&transpose.offset.HostVector(),
-                                              &transpose.data.HostVector());
+  common::ParallelGroupBuilder<Entry, bst_row_t> builder(&transpose.offset.HostVector(),
+                                                         &transpose.data.HostVector());
   const int nthread = omp_get_max_threads();
   builder.InitBudget(num_columns, nthread);
@@ -424,7 +423,7 @@ void SparsePage::Push(const dmlc::RowBlock<uint32_t>& batch) {
 void SparsePage::PushCSC(const SparsePage &batch) {
   std::vector<xgboost::Entry>& self_data = data.HostVector();
-  std::vector<size_t>& self_offset = offset.HostVector();
+  std::vector<bst_row_t>& self_offset = offset.HostVector();

   auto const& other_data = batch.data.ConstHostVector();
   auto const& other_offset = batch.offset.ConstHostVector();
@@ -442,7 +441,7 @@ void SparsePage::PushCSC(const SparsePage &batch) {
     return;
   }

-  std::vector<size_t> offset(other_offset.size());
+  std::vector<bst_row_t> offset(other_offset.size());
   offset[0] = 0;

   std::vector<xgboost::Entry> data(self_data.size() + other_data.size());

View File

@@ -29,7 +29,7 @@ namespace data {
 template <typename T>
 __global__ void CountValidKernel(Columnar<T> const column,
                                  bool has_missing, float missing,
-                                 int32_t* flag, common::Span<size_t> offsets) {
+                                 int32_t* flag, common::Span<bst_row_t> offsets) {
   auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
   bool const missing_is_nan = common::CheckNAN(missing);
@@ -59,7 +59,7 @@ __global__ void CountValidKernel(Columnar<T> const column,
 template <typename T>
 __device__ void AssignValue(T fvalue, int32_t colid,
-                            common::Span<size_t> out_offsets, common::Span<Entry> out_data) {
+                            common::Span<bst_row_t> out_offsets, common::Span<Entry> out_data) {
   auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
   int32_t oid = out_offsets[tid];
   out_data[oid].fvalue = fvalue;
@@ -70,7 +70,7 @@ __device__ void AssignValue(T fvalue, int32_t colid,
 template <typename T>
 __global__ void CreateCSRKernel(Columnar<T> const column,
                                 int32_t colid, bool has_missing, float missing,
-                                common::Span<size_t> offsets, common::Span<Entry> out_data) {
+                                common::Span<bst_row_t> offsets, common::Span<Entry> out_data) {
   auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
   if (column.size <= tid) {
     return;
@@ -98,7 +98,7 @@ __global__ void CreateCSRKernel(Columnar<T> const column,
 template <typename T>
 void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
                 bool has_missing, float missing,
-                HostDeviceVector<size_t>* out_offset,
+                HostDeviceVector<bst_row_t>* out_offset,
                 dh::caching_device_vector<int32_t>* out_d_flag,
                 uint32_t* out_n_rows) {
   uint32_t constexpr kThreads = 256;
@@ -121,7 +121,7 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
   CHECK_EQ(out_offset->Size(), n_rows + 1)
       << "All columns should have same number of rows.";
-  common::Span<size_t> s_offsets = out_offset->DeviceSpan();
+  common::Span<bst_row_t> s_offsets = out_offset->DeviceSpan();
   uint32_t const kBlocks = common::DivRoundUp(n_rows, kThreads);
   dh::LaunchKernel {kBlocks, kThreads} (
@@ -135,7 +135,7 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
 template <typename T>
 void CreateCSR(std::vector<Json> const& j_columns, uint32_t column_id, uint32_t n_rows,
                bool has_missing, float missing,
-               dh::device_vector<size_t>* tmp_offset, common::Span<Entry> s_data) {
+               dh::device_vector<bst_row_t>* tmp_offset, common::Span<Entry> s_data) {
   uint32_t constexpr kThreads = 256;
   auto const& j_column = j_columns[column_id];
   auto const& column_obj = get<Object const>(j_column);
@@ -174,13 +174,13 @@ void SimpleCSRSource::FromDeviceColumnar(std::vector<Json> const& columns,
   info.num_row_ = n_rows;

   auto s_offsets = this->page_.offset.DeviceSpan();
-  thrust::device_ptr<size_t> p_offsets(s_offsets.data());
+  thrust::device_ptr<bst_row_t> p_offsets(s_offsets.data());
   CHECK_GE(s_offsets.size(), n_rows + 1);

   thrust::inclusive_scan(p_offsets, p_offsets + n_rows + 1, p_offsets);
   // Created for building csr matrix, where we need to change index after processing each
   // column.
-  dh::device_vector<size_t> tmp_offset(this->page_.offset.Size());
+  dh::device_vector<bst_row_t> tmp_offset(this->page_.offset.Size());
   dh::safe_cuda(cudaMemcpy(tmp_offset.data().get(), s_offsets.data(),
                            s_offsets.size_bytes(), cudaMemcpyDeviceToDevice));

View File

@@ -80,13 +80,13 @@ struct DevicePredictionNode {
 struct ElementLoader {
   bool use_shared;
-  common::Span<const size_t> d_row_ptr;
+  common::Span<const bst_row_t> d_row_ptr;
   common::Span<const Entry> d_data;
   int num_features;
   float* smem;
   size_t entry_start;

-  __device__ ElementLoader(bool use_shared, common::Span<const size_t> row_ptr,
+  __device__ ElementLoader(bool use_shared, common::Span<const bst_row_t> row_ptr,
                            common::Span<const Entry> entry, int num_features,
                            float* smem, int num_rows, size_t entry_start)
       : use_shared(use_shared),
@@ -166,7 +166,7 @@ __global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
                               common::Span<float> d_out_predictions,
                               common::Span<size_t> d_tree_segments,
                               common::Span<int> d_tree_group,
-                              common::Span<const size_t> d_row_ptr,
+                              common::Span<const bst_row_t> d_row_ptr,
                               common::Span<const Entry> d_data, size_t tree_begin,
                               size_t tree_end, size_t num_features,
                               size_t num_rows, size_t entry_start,

src/tree/constraints.cc (new file, 105 lines)
View File

@@ -0,0 +1,105 @@
/*!
 * Copyright 2018-2019 by Contributors
 */
#include <algorithm>
#include <unordered_set>
#include <vector>

#include "xgboost/span.h"
#include "constraints.h"
#include "param.h"

namespace xgboost {

void FeatureInteractionConstraintHost::Configure(tree::TrainParam const& param,
                                                 bst_feature_t const n_features) {
  if (param.interaction_constraints.empty()) {
    enabled_ = false;
    return;  // short-circuit if no constraint is specified
  }
  enabled_ = true;
  this->interaction_constraint_str_ = param.interaction_constraints;
  this->n_features_ = n_features;
  this->Reset();
}

void FeatureInteractionConstraintHost::Reset() {
  if (!enabled_) {
    return;
  }
  // Parse interaction constraints
  std::istringstream iss(this->interaction_constraint_str_);
  dmlc::JSONReader reader(&iss);
  // Read std::vector<std::vector<bst_uint>> first and then
  //   convert to std::vector<std::unordered_set<bst_uint>>
  std::vector<std::vector<bst_uint>> tmp;
  try {
    reader.Read(&tmp);
  } catch (dmlc::Error const& e) {
    LOG(FATAL) << "Failed to parse feature interaction constraint:\n"
               << this->interaction_constraint_str_ << "\n"
               << "With error:\n" << e.what();
  }
  for (const auto& e : tmp) {
    interaction_constraints_.emplace_back(e.begin(), e.end());
  }

  // Initialise interaction constraints record with all variables permitted for the first node
  node_constraints_.clear();
  node_constraints_.resize(1, std::unordered_set<bst_feature_t>());
  node_constraints_[0].reserve(n_features_);
  for (bst_feature_t i = 0; i < n_features_; ++i) {
    node_constraints_[0].insert(i);
  }

  // Initialise splits record
  splits_.clear();
  splits_.resize(1, std::unordered_set<bst_feature_t>());
}

void FeatureInteractionConstraintHost::SplitImpl(
    bst_node_t node_id, bst_feature_t feature_id, bst_node_t left_id, bst_node_t right_id) {
  bst_node_t newsize = std::max(left_id, right_id) + 1;

  // Record previous splits for child nodes
  auto feature_splits = splits_[node_id];  // fid history of current node
  feature_splits.insert(feature_id);       // add feature of current node
  splits_.resize(newsize);
  splits_[left_id] = feature_splits;
  splits_[right_id] = feature_splits;

  // Resize constraints record, initialise all features to be not permitted for new nodes
  CHECK_NE(newsize, 0);
  node_constraints_.resize(newsize, std::unordered_set<bst_feature_t>());

  // Permit features used in previous splits
  for (bst_feature_t fid : feature_splits) {
    node_constraints_[left_id].insert(fid);
    node_constraints_[right_id].insert(fid);
  }

  // Loop across specified interactions in constraints
  for (const auto &constraint : interaction_constraints_) {
    // flags whether the specified interaction is still relevant
    bst_uint flag = 1;

    // Test relevance of specified interaction by checking all previous
    // features are included
    for (bst_uint checkvar : feature_splits) {
      if (constraint.count(checkvar) == 0) {
        flag = 0;
        break;  // interaction is not relevant due to unmet constraint
      }
    }

    // If interaction is still relevant, permit all other features in the
    // interaction
    if (flag == 1) {
      for (bst_uint k : constraint) {
        node_constraints_[left_id].insert(k);
        node_constraints_[right_id].insert(k);
      }
    }
  }
}
}  // namespace xgboost

View File

@@ -173,7 +173,7 @@ void FeatureInteractionConstraint::ClearBuffers() {
       output_buffer_bits_, input_buffer_bits_);
 }

-common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
+common::Span<bst_feature_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
   if (!has_constraint_) { return {}; }
   CHECK_LT(node_id, s_node_constraints_.size());
@@ -184,7 +184,7 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
   auto p_result_buffer = result_buffer_.data();
   LBitField64 node_constraints = s_node_constraints_[node_id];

-  thrust::device_ptr<int32_t> const out_end = thrust::copy_if(
+  thrust::device_ptr<bst_feature_t> const out_end = thrust::copy_if(
       thrust::device,
       begin, end,
       p_result_buffer,
@@ -197,7 +197,7 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
   return {s_result_buffer_.data(), s_result_buffer_.data() + n_available};
 }

-__global__ void SetInputBufferKernel(common::Span<int32_t> feature_list_input,
+__global__ void SetInputBufferKernel(common::Span<bst_feature_t> feature_list_input,
                                      LBitField64 result_buffer_input) {
   uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
   if (tid < feature_list_input.size()) {
@@ -212,8 +212,8 @@ __global__ void QueryFeatureListKernel(LBitField64 node_constraints,
   result_buffer_output &= result_buffer_input;
 }

-common::Span<int32_t> FeatureInteractionConstraint::Query(
-    common::Span<int32_t> feature_list, int32_t nid) {
+common::Span<bst_feature_t> FeatureInteractionConstraint::Query(
+    common::Span<bst_feature_t> feature_list, int32_t nid) {
   if (!has_constraint_ || nid == 0) {
     return feature_list;
   }
@@ -238,7 +238,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
   LBitField64 local_result_buffer = output_buffer_bits_;

-  thrust::device_ptr<int32_t> const out_end = thrust::copy_if(
+  thrust::device_ptr<bst_feature_t> const out_end = thrust::copy_if(
       thrust::device,
       begin, end,
       result_buffer_.data(),
@@ -248,7 +248,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
       });
   size_t const n_available = std::distance(result_buffer_.data(), out_end);

-  common::Span<int32_t> result =
+  common::Span<bst_feature_t> result =
       {s_result_buffer_.data(), s_result_buffer_.data() + n_available};
   return result;
 }
@@ -258,12 +258,12 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
 __global__ void RestoreFeatureListFromSetsKernel(
     LBitField64 feature_buffer,
-    int32_t fid,
+    bst_feature_t fid,
     common::Span<int32_t> feature_interactions,
     common::Span<int32_t> feature_interactions_ptr,  // of size n interaction set + 1
-    common::Span<int32_t> interactions_list,
-    common::Span<int32_t> interactions_list_ptr) {
+    common::Span<bst_feature_t> interactions_list,
+    common::Span<size_t> interactions_list_ptr) {
   auto const tid_x = threadIdx.x + blockIdx.x * blockDim.x;
   auto const tid_y = threadIdx.y + blockIdx.y * blockDim.y;
   // painful mapping: fid -> sets related to it -> features related to sets.
@@ -312,7 +312,7 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature,
 }

 void FeatureInteractionConstraint::Split(
-    int32_t node_id, int32_t feature_id, int32_t left_id, int32_t right_id) {
+    bst_node_t node_id, bst_feature_t feature_id, bst_node_t left_id, bst_node_t right_id) {
   if (!has_constraint_) { return; }
   CHECK_NE(node_id, left_id)
       << " Split node: " << node_id << " and its left child: "

View File

@@ -108,10 +108,10 @@ struct FeatureInteractionConstraint {
   *
   * d_sets_ptr_: |0, 1, 3, 4|
   */
-  dh::device_vector<int32_t> d_sets_;
-  common::Span<int32_t> s_sets_;
-  dh::device_vector<int32_t> d_sets_ptr_;
-  common::Span<int32_t> s_sets_ptr_;
+  dh::device_vector<bst_feature_t> d_sets_;
+  common::Span<bst_feature_t> s_sets_;
+  dh::device_vector<size_t> d_sets_ptr_;
+  common::Span<size_t> s_sets_ptr_;

   // Allowed features attached to each node, have n_nodes bitfields,
   // each of size n_features.
@@ -120,8 +120,8 @@ struct FeatureInteractionConstraint {
   common::Span<LBitField64> s_node_constraints_;

   // buffer storing return feature list from Query, of size n_features.
-  dh::device_vector<int32_t> result_buffer_;
-  common::Span<int32_t> s_result_buffer_;
+  dh::device_vector<bst_feature_t> result_buffer_;
+  common::Span<bst_feature_t> s_result_buffer_;

   // Temp buffers, one bit for each possible feature.
   dh::device_vector<LBitField64::value_type> output_buffer_bits_storage_;
@@ -149,7 +149,7 @@ struct FeatureInteractionConstraint {
   /*! \brief Reset before constructing a new tree. */
   void Reset();
   /*! \brief Return a list of features given node id */
-  common::Span<int32_t> QueryNode(int32_t nid);
+  common::Span<bst_feature_t> QueryNode(int32_t nid);
   /*!
    * \brief Return a list of selected features from given feature_list and node id.
    *
@@ -159,9 +159,9 @@ struct FeatureInteractionConstraint {
    * \return A list of features picked from `feature_list' that conform to constraints in
    * node.
    */
-  common::Span<int32_t> Query(common::Span<int32_t> feature_list, int32_t nid);
+  common::Span<bst_feature_t> Query(common::Span<bst_feature_t> feature_list, int32_t nid);
   /*! \brief Apply split for node_id. */
-  void Split(int32_t node_id, int32_t feature_id, int32_t left_id, int32_t right_id);
+  void Split(bst_node_t node_id, bst_feature_t feature_id, bst_node_t left_id, bst_node_t right_id);
 };
 }  // namespace xgboost

src/tree/constraints.h (new file, 67 lines)
View File

@@ -0,0 +1,67 @@
/*!
 * Copyright 2018-2019 by Contributors
 */
#ifndef XGBOOST_TREE_CONSTRAINTS_H_
#define XGBOOST_TREE_CONSTRAINTS_H_

#include <string>
#include <unordered_set>
#include <vector>

#include "xgboost/span.h"
#include "xgboost/base.h"
#include "param.h"

namespace xgboost {
/*!
 * \brief Feature interaction constraint implementation for CPU tree updaters.
 *
 * The interface is similar to the one for GPU Hist.
 */
class FeatureInteractionConstraintHost {
 protected:
  // interaction_constraints_[constraint_id] contains a single interaction
  //   constraint, which specifies a group of feature IDs that can interact
  //   with each other
  std::vector< std::unordered_set<bst_feature_t> > interaction_constraints_;
  // node_constraints_[nid] contains the set of all feature IDs that are allowed to
  //   be used for a split at node nid
  std::vector< std::unordered_set<bst_feature_t> > node_constraints_;
  // splits_[nid] contains the set of all feature IDs that have been used for
  //   splits in node nid and its parents
  std::vector< std::unordered_set<bst_feature_t> > splits_;
  std::vector<bst_feature_t> return_buffer;
  // string passed by user.
  std::string interaction_constraint_str_;
  // number of features in DMatrix/Booster
  bst_feature_t n_features_;
  bool enabled_{false};

  void SplitImpl(bst_node_t node_id, bst_feature_t feature_id, bst_node_t left_id,
                 bst_node_t right_id);

 public:
  FeatureInteractionConstraintHost() = default;

  void Split(bst_node_t node_id, bst_feature_t feature_id, bst_node_t left_id,
             bst_node_t right_id) {
    if (!enabled_) {
      return;
    } else {
      this->SplitImpl(node_id, feature_id, left_id, right_id);
    }
  }
  bool Query(bst_node_t nid, bst_feature_t fid) const {
    if (!enabled_) { return true; }
    return node_constraints_.at(nid).find(fid) != node_constraints_.at(nid).cend();
  }
  void Reset();
  void Configure(tree::TrainParam const& param, bst_feature_t const n_features);
};
}  // namespace xgboost

#endif  // XGBOOST_TREE_CONSTRAINTS_H_
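A hypothetical driver showing how a CPU updater is expected to use this class (the parameter object and the feature/node ids below are made up for illustration):

// Somewhere inside a tree updater, after parameters are configured:
FeatureInteractionConstraintHost ic;
ic.Configure(param_, /*n_features=*/5);  // param_ carries interaction_constraints
ic.Reset();                              // start of a new tree
// At the root every feature is permitted, so Query(0, fid) returns true.
// Suppose feature 1 is chosen to split the root into nodes 1 and 2:
ic.Split(/*node_id=*/0, /*feature_id=*/1, /*left_id=*/1, /*right_id=*/2);
// In the children, a feature stays feasible only if it shares a constraint
// group with every feature already used on the path:
bool feasible = ic.Query(/*nid=*/1, /*fid=*/3);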

View File

@@ -10,23 +10,22 @@ namespace xgboost {
 namespace tree {

 struct IndicateLeftTransform {
-  RowPartitioner::TreePositionT left_nidx;
-  explicit IndicateLeftTransform(RowPartitioner::TreePositionT left_nidx)
+  bst_node_t left_nidx;
+  explicit IndicateLeftTransform(bst_node_t left_nidx)
       : left_nidx(left_nidx) {}
-  __host__ __device__ __forceinline__ int operator()(
-      const RowPartitioner::TreePositionT& x) const {
+  __host__ __device__ __forceinline__ int operator()(const bst_node_t& x) const {
     return x == left_nidx ? 1 : 0;
   }
 };

 /*
  * position: Position of rows belonged to current split node.
  */
-void RowPartitioner::SortPosition(common::Span<TreePositionT> position,
-                                  common::Span<TreePositionT> position_out,
+void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
+                                  common::Span<bst_node_t> position_out,
                                   common::Span<RowIndexT> ridx,
                                   common::Span<RowIndexT> ridx_out,
-                                  TreePositionT left_nidx,
-                                  TreePositionT right_nidx,
+                                  bst_node_t left_nidx,
+                                  bst_node_t right_nidx,
                                   int64_t* d_left_count, cudaStream_t stream) {
   // radix sort over 1 bit, see:
   // https://developer.nvidia.com/gpugems/GPUGems3/gpugems3_ch39.html
@@ -53,8 +52,8 @@ void RowPartitioner::SortPosition(common::Span<TreePositionT> position,
   IndicateLeftTransform is_left(left_nidx);
   // an iterator that given a old position returns whether it belongs to left or right
   // node.
-  cub::TransformInputIterator<TreePositionT, IndicateLeftTransform,
-                              TreePositionT*>
+  cub::TransformInputIterator<bst_node_t, IndicateLeftTransform,
+                              bst_node_t*>
       in_itr(d_position_in, is_left);
   dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
   size_t temp_storage_bytes = 0;
@@ -73,7 +72,7 @@ RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
   position_a.resize(num_rows);
   position_b.resize(num_rows);
   ridx = dh::DoubleBuffer<RowIndexT>{&ridx_a, &ridx_b};
-  position = dh::DoubleBuffer<TreePositionT>{&position_a, &position_b};
+  position = dh::DoubleBuffer<bst_node_t>{&position_a, &position_b};
   ridx_segments.emplace_back(Segment(0, num_rows));

   thrust::sequence(
@@ -97,7 +96,7 @@ RowPartitioner::~RowPartitioner() {
 }

 common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
-    TreePositionT nidx) {
+    bst_node_t nidx) {
   auto segment = ridx_segments.at(nidx);
   // Return empty span here as a valid result
   // Will error if we try to construct a span from a pointer with size 0
@@ -111,36 +110,35 @@ common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
   return ridx.CurrentSpan();
 }

-common::Span<const RowPartitioner::TreePositionT>
-RowPartitioner::GetPosition() {
+common::Span<const bst_node_t> RowPartitioner::GetPosition() {
   return position.CurrentSpan();
 }
 std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
-    TreePositionT nidx) {
+    bst_node_t nidx) {
   auto span = GetRows(nidx);
   std::vector<RowIndexT> rows(span.size());
   dh::CopyDeviceSpanToVector(&rows, span);
   return rows;
 }

-std::vector<RowPartitioner::TreePositionT> RowPartitioner::GetPositionHost() {
+std::vector<bst_node_t> RowPartitioner::GetPositionHost() {
   auto span = GetPosition();
-  std::vector<TreePositionT> position(span.size());
+  std::vector<bst_node_t> position(span.size());
   dh::CopyDeviceSpanToVector(&position, span);
   return position;
 }

 void RowPartitioner::SortPositionAndCopy(const Segment& segment,
-                                         TreePositionT left_nidx,
-                                         TreePositionT right_nidx,
+                                         bst_node_t left_nidx,
+                                         bst_node_t right_nidx,
                                          int64_t* d_left_count,
                                          cudaStream_t stream) {
   SortPosition(
       // position_in
-      common::Span<TreePositionT>(position.Current() + segment.begin,
-                                  segment.Size()),
+      common::Span<bst_node_t>(position.Current() + segment.begin,
+                               segment.Size()),
       // position_out
-      common::Span<TreePositionT>(position.other() + segment.begin,
-                                  segment.Size()),
+      common::Span<bst_node_t>(position.other() + segment.begin,
+                               segment.Size()),
       // row index in
       common::Span<RowIndexT>(ridx.Current() + segment.begin, segment.Size()),

View File

@@ -2,6 +2,7 @@
  * Copyright 2017-2019 XGBoost contributors
  */
 #pragma once
+#include "xgboost/base.h"
 #include "../../common/device_helpers.cuh"

 namespace xgboost {
@@ -30,7 +31,6 @@ __forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment
  * partition training rows into different leaf nodes. */
 class RowPartitioner {
  public:
-  using TreePositionT = int32_t;
   using RowIndexT = bst_uint;
   struct Segment;

@@ -47,8 +47,8 @@ class RowPartitioner {
   std::vector<Segment> ridx_segments;
   dh::caching_device_vector<RowIndexT> ridx_a;
   dh::caching_device_vector<RowIndexT> ridx_b;
-  dh::caching_device_vector<TreePositionT> position_a;
-  dh::caching_device_vector<TreePositionT> position_b;
+  dh::caching_device_vector<bst_node_t> position_a;
+  dh::caching_device_vector<bst_node_t> position_b;
   /*! \brief mapping for node id -> rows.
    * This looks like:
    * node id  |    1    |    2   |
@@ -56,7 +56,7 @@ class RowPartitioner {
    */
   dh::DoubleBuffer<RowIndexT> ridx;
   /*! \brief mapping for row -> node id. */
-  dh::DoubleBuffer<TreePositionT> position;
+  dh::DoubleBuffer<bst_node_t> position;
   dh::caching_device_vector<int64_t>
       left_counts;  // Useful to keep a bunch of zeroed memory for sort position
   std::vector<cudaStream_t> streams;
@@ -70,7 +70,7 @@ class RowPartitioner {
   /**
    * \brief Gets the row indices of training instances in a given node.
    */
-  common::Span<const RowIndexT> GetRows(TreePositionT nidx);
+  common::Span<const RowIndexT> GetRows(bst_node_t nidx);

   /**
    * \brief Gets all training rows in the set.
@@ -80,17 +80,17 @@ class RowPartitioner {
   /**
    * \brief Gets the tree position of all training instances.
    */
-  common::Span<const TreePositionT> GetPosition();
+  common::Span<const bst_node_t> GetPosition();

   /**
    * \brief Convenience method for testing
    */
-  std::vector<RowIndexT> GetRowsHost(TreePositionT nidx);
+  std::vector<RowIndexT> GetRowsHost(bst_node_t nidx);

   /**
    * \brief Convenience method for testing
    */
-  std::vector<TreePositionT> GetPositionHost();
+  std::vector<bst_node_t> GetPositionHost();

   /**
    * \brief Updates the tree position for set of training instances being split
@@ -105,8 +105,8 @@ class RowPartitioner {
    *   argument and return the new position for this training instance.
    */
   template <typename UpdatePositionOpT>
-  void UpdatePosition(TreePositionT nidx, TreePositionT left_nidx,
-                      TreePositionT right_nidx, UpdatePositionOpT op) {
+  void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx,
+                      bst_node_t right_nidx, UpdatePositionOpT op) {
     dh::safe_cuda(cudaSetDevice(device_idx));
     Segment segment = ridx_segments.at(nidx);  // rows belongs to node nidx
     auto d_ridx = ridx.CurrentSpan();
@@ -123,7 +123,7 @@ class RowPartitioner {
       // LaunchN starts from zero, so we restore the row index by adding segment.begin
      idx += segment.begin;
       RowIndexT ridx = d_ridx[idx];
-      TreePositionT new_position = op(ridx);  // new node id
+      bst_node_t new_position = op(ridx);  // new node id
       KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
       AtomicIncrement(d_left_count, new_position == left_nidx);
       d_position[idx] = new_position;
@@ -172,16 +172,16 @@ class RowPartitioner {
    * segments. Based on a single pass of exclusive scan, uses iterators to
    * redirect inputs and outputs.
    */
-  void SortPosition(common::Span<TreePositionT> position,
-                    common::Span<TreePositionT> position_out,
+  void SortPosition(common::Span<bst_node_t> position,
+                    common::Span<bst_node_t> position_out,
                     common::Span<RowIndexT> ridx,
-                    common::Span<RowIndexT> ridx_out, TreePositionT left_nidx,
-                    TreePositionT right_nidx, int64_t* d_left_count,
+                    common::Span<RowIndexT> ridx_out, bst_node_t left_nidx,
+                    bst_node_t right_nidx, int64_t* d_left_count,
                     cudaStream_t stream = nullptr);
   /*! \brief Sort row indices according to position. */
-  void SortPositionAndCopy(const Segment& segment, TreePositionT left_nidx,
-                           TreePositionT right_nidx, int64_t* d_left_count,
+  void SortPositionAndCopy(const Segment& segment, bst_node_t left_nidx,
+                           bst_node_t right_nidx, int64_t* d_left_count,
                            cudaStream_t stream);
   /** \brief Used to demarcate a contiguous set of row indices associated with
    * some tree node. */

View File

@@ -194,7 +194,7 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
               "indices of features that are allowed to interact with each other."
               "See tutorial for more information");
     DMLC_DECLARE_FIELD(split_evaluator)
-        .set_default("elastic_net,monotonic,interaction")
+        .set_default("elastic_net,monotonic")
         .describe("The criteria to use for ranking splits");

     // ------ From cpu quantile histogram -------.

View File

@ -64,10 +64,6 @@ bst_float SplitEvaluator::ComputeSplitScore(bst_uint nodeid,
return ComputeSplitScore(nodeid, featureid, left_stats, right_stats, left_weight, right_weight); return ComputeSplitScore(nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
} }
bool SplitEvaluator::CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const {
return true;
}
//! \brief Encapsulates the parameters for ElasticNet //! \brief Encapsulates the parameters for ElasticNet
struct ElasticNetParams : public XGBoostParameter<ElasticNetParams> { struct ElasticNetParams : public XGBoostParameter<ElasticNetParams> {
bst_float reg_lambda; bst_float reg_lambda;
@ -159,10 +155,6 @@ class ElasticNet final : public SplitEvaluator {
return w; return w;
} }
bool CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const override {
return true;
}
private: private:
ElasticNetParams params_; ElasticNetParams params_;
@ -307,10 +299,6 @@ class MonotonicConstraint final : public SplitEvaluator {
} }
} }
bool CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const override {
return true;
}
private: private:
MonotonicConstraintParams params_; MonotonicConstraintParams params_;
std::unique_ptr<SplitEvaluator> inner_; std::unique_ptr<SplitEvaluator> inner_;
@ -332,207 +320,5 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(MonotonicConstraint, "monotonic")
.set_body([](std::unique_ptr<SplitEvaluator> inner) { .set_body([](std::unique_ptr<SplitEvaluator> inner) {
return new MonotonicConstraint(std::move(inner)); return new MonotonicConstraint(std::move(inner));
}); });
/*! \brief Encapsulates the parameters required by the InteractionConstraint
split evaluator
*/
struct InteractionConstraintParams
: public XGBoostParameter<InteractionConstraintParams> {
std::string interaction_constraints;
bst_uint num_feature;
DMLC_DECLARE_PARAMETER(InteractionConstraintParams) {
DMLC_DECLARE_FIELD(interaction_constraints)
.set_default("")
.describe("Constraints for interaction representing permitted interactions."
"The constraints must be specified in the form of a nest list,"
"e.g. [[0, 1], [2, 3, 4]], where each inner list is a group of"
"indices of features that are allowed to interact with each other."
"See tutorial for more information");
DMLC_DECLARE_FIELD(num_feature)
.describe("Number of total features used");
}
};
DMLC_REGISTER_PARAMETER(InteractionConstraintParams);
/*! \brief Enforces that the tree is monotonically increasing/decreasing with respect to a user specified set of
features.
*/
class InteractionConstraint final : public SplitEvaluator {
public:
explicit InteractionConstraint(std::unique_ptr<SplitEvaluator> inner) {
if (!inner) {
LOG(FATAL) << "InteractionConstraint must be given an inner evaluator";
}
inner_ = std::move(inner);
}
void Init(const Args& args)
override {
inner_->Init(args);
params_.UpdateAllowUnknown(args);
Reset();
}
void Reset() override {
if (params_.interaction_constraints.empty()) {
return; // short-circuit if no constraint is specified
}
// Parse interaction constraints
std::istringstream iss(params_.interaction_constraints);
dmlc::JSONReader reader(&iss);
// Read std::vector<std::vector<bst_uint>> first and then
// convert to std::vector<std::unordered_set<bst_uint>>
std::vector<std::vector<bst_uint>> tmp;
try {
reader.Read(&tmp);
} catch (dmlc::Error const& e) {
LOG(FATAL) << "Failed to parse feature interaction constraint:\n"
<< params_.interaction_constraints << "\n"
<< "With error:\n" << e.what();
}
for (const auto& e : tmp) {
interaction_constraints_.emplace_back(e.begin(), e.end());
}
// Initialise interaction constraints record with all variables permitted for the first node
node_constraints_.clear();
node_constraints_.resize(1, std::unordered_set<bst_uint>());
node_constraints_[0].reserve(params_.num_feature);
    for (bst_uint i = 0; i < params_.num_feature; ++i) {
      node_constraints_[0].insert(i);
    }
    // Initialise splits record
    splits_.clear();
    splits_.resize(1, std::unordered_set<bst_uint>());
  }

  SplitEvaluator* GetHostClone() const override {
    if (params_.interaction_constraints.empty()) {
      // No interaction constraints specified, just return a clone of inner
      return inner_->GetHostClone();
    } else {
      auto c = new InteractionConstraint(
          std::unique_ptr<SplitEvaluator>(inner_->GetHostClone()));
      c->params_ = this->params_;
      c->Reset();
      return c;
    }
  }

  bst_float ComputeSplitScore(bst_uint nodeid,
                              bst_uint featureid,
                              const GradStats& left_stats,
                              const GradStats& right_stats,
                              bst_float left_weight,
                              bst_float right_weight) const override {
    // Return negative infinity score if feature is not permitted by interaction constraints
    if (!CheckInteractionConstraint(featureid, nodeid)) {
      return -std::numeric_limits<bst_float>::infinity();
    }
    // Otherwise, get score from inner evaluator
    bst_float score = inner_->ComputeSplitScore(
        nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
    return score;
  }

  bst_float ComputeScore(bst_uint parentID, const GradStats& stats, bst_float weight)
      const override {
    return inner_->ComputeScore(parentID, stats, weight);
  }

  bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
      const override {
    return inner_->ComputeWeight(parentID, stats);
  }

  void AddSplit(bst_uint nodeid,
                bst_uint leftid,
                bst_uint rightid,
                bst_uint featureid,
                bst_float leftweight,
                bst_float rightweight) override {
    inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
    if (params_.interaction_constraints.empty()) {
      return;  // short-circuit if no constraint is specified
    }
    bst_uint newsize = std::max(leftid, rightid) + 1;

    // Record previous splits for child nodes
    std::unordered_set<bst_uint> feature_splits = splits_[nodeid];  // fid history of current node
    feature_splits.insert(featureid);  // add feature of current node
    splits_.resize(newsize);
    splits_[leftid] = feature_splits;
    splits_[rightid] = feature_splits;

    // Resize constraints record, initialise all features to be not permitted for new nodes
    node_constraints_.resize(newsize, std::unordered_set<bst_uint>());

    // Permit features used in previous splits
    for (bst_uint fid : feature_splits) {
      node_constraints_[leftid].insert(fid);
      node_constraints_[rightid].insert(fid);
    }

    // Loop across specified interactions in constraints
    for (const auto& constraint : interaction_constraints_) {
      bst_uint flag = 1;  // flags whether the specified interaction is still relevant

      // Test relevance of specified interaction by checking all previous features are included
      for (bst_uint checkvar : feature_splits) {
        if (constraint.count(checkvar) == 0) {
          flag = 0;
          break;  // interaction is not relevant due to unmet constraint
        }
      }

      // If interaction is still relevant, permit all other features in the interaction
      if (flag == 1) {
        for (bst_uint k : constraint) {
          node_constraints_[leftid].insert(k);
          node_constraints_[rightid].insert(k);
        }
      }
    }
  }

  bool CheckFeatureConstraint(bst_uint nodeid, bst_uint featureid) const override {
    return CheckInteractionConstraint(featureid, nodeid);
  }

 private:
  InteractionConstraintParams params_;
  std::unique_ptr<SplitEvaluator> inner_;

  // interaction_constraints_[constraint_id] contains a single interaction
  // constraint, which specifies a group of feature IDs that can interact
  // with each other
  std::vector< std::unordered_set<bst_uint> > interaction_constraints_;
  // node_constraints_[nid] contains the set of all feature IDs that are
  // allowed to be used for a split at node nid
  std::vector< std::unordered_set<bst_uint> > node_constraints_;
  // splits_[nid] contains the set of all feature IDs that have been used for
  // splits in node nid and its parents
  std::vector< std::unordered_set<bst_uint> > splits_;

  // Check interaction constraints. Returns true if a given feature ID is
  // permissible in a given node; returns false otherwise
  inline bool CheckInteractionConstraint(bst_uint featureid, bst_uint nodeid) const {
    // short-circuit if no constraint is specified
    return (params_.interaction_constraints.empty()
            || node_constraints_.at(nodeid).count(featureid) > 0);
  }
};

XGBOOST_REGISTER_SPLIT_EVALUATOR(InteractionConstraint, "interaction")
.describe("Enforces interaction constraints on tree features")
.set_body([](std::unique_ptr<SplitEvaluator> inner) {
    return new InteractionConstraint(std::move(inner));
  });

}  // namespace tree
}  // namespace xgboost

View File

@@ -68,11 +68,6 @@ class SplitEvaluator {
                         bst_uint featureid,
                         bst_float leftweight,
                         bst_float rightweight);
-  // Check whether a given feature is feasible for a given node.
-  // Use this function to narrow the search space for split candidates
-  virtual bool CheckFeatureConstraint(bst_uint nodeid,
-                                      bst_uint featureid) const = 0;
 };

 struct SplitEvaluatorReg

View File

@@ -9,15 +9,18 @@
 #include <rabit/rabit.h>
-#include <xgboost/base.h>
-#include <xgboost/tree_updater.h>
+
 #include <vector>
 #include <algorithm>
 #include <string>
 #include <limits>
 #include <utility>
-#include "./param.h"
+
+#include "xgboost/base.h"
+#include "xgboost/tree_updater.h"
+#include "param.h"
+#include "constraints.h"
 #include "../common/io.h"
 #include "../common/random.h"
 #include "../common/quantile.h"
@@ -75,11 +78,12 @@ class BaseMaker: public TreeUpdater {
       return 2;
     }
   }
-  inline bst_float MaxValue(bst_uint fid) const {
+  bst_float MaxValue(bst_uint fid) const {
     return fminmax_[fid *2 + 1];
   }
-  inline void SampleCol(float p, std::vector<bst_uint> *p_findex) const {
-    std::vector<bst_uint> &findex = *p_findex;
+
+  void SampleCol(float p, std::vector<bst_feature_t> *p_findex) const {
+    std::vector<bst_feature_t> &findex = *p_findex;
     findex.clear();
     for (size_t i = 0; i < fminmax_.size(); i += 2) {
       const auto fid = static_cast<bst_uint>(i / 2);
@@ -161,6 +165,7 @@ class BaseMaker: public TreeUpdater {
       }
       this->UpdateNode2WorkIndex(tree);
     }
+    this->interaction_constraints_.Configure(param_, fmat.Info().num_col_);
   }
   /*! \brief update queue expand add in new leaves */
   inline void UpdateQueueExpand(const RegTree &tree) {
@@ -462,6 +467,8 @@ class BaseMaker: public TreeUpdater {
    */
   std::vector<int> position_;
+
+  FeatureInteractionConstraintHost interaction_constraints_;

  private:
   inline void UpdateNode2WorkIndex(const RegTree &tree) {
     // update the node2workindex

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014 by Contributors * Copyright 2014-2019 by Contributors
* \file updater_colmaker.cc * \file updater_colmaker.cc
* \brief use columnwise update to construct a tree * \brief use columnwise update to construct a tree
* \author Tianqi Chen * \author Tianqi Chen
@ -13,6 +13,7 @@
#include <algorithm> #include <algorithm>
#include "param.h" #include "param.h"
#include "constraints.h"
#include "../common/random.h" #include "../common/random.h"
#include "../common/bitmap.h" #include "../common/bitmap.h"
#include "split_evaluator.h" #include "split_evaluator.h"
@@ -41,11 +42,13 @@ class ColMaker: public TreeUpdater {
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
     param_.learning_rate = lr / trees.size();
+    interaction_constraints_.Configure(param_, dmat->Info().num_row_);
     // build tree
     for (auto tree : trees) {
       Builder builder(
         param_,
-        std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()));
+        std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
+        interaction_constraints_);
       builder.Update(gpair->ConstHostVector(), dmat, tree);
     }
     param_.learning_rate = lr;
@@ -56,6 +59,8 @@ class ColMaker: public TreeUpdater {
   TrainParam param_;
   // SplitEvaluator that will be cloned for each Builder
   std::unique_ptr<SplitEvaluator> spliteval_;
+  FeatureInteractionConstraintHost interaction_constraints_;
+
   // data structure
   /*! \brief per thread x per node entry to store tmp data */
   struct ThreadEntry {
@@ -89,9 +94,11 @@ class ColMaker: public TreeUpdater {
    public:
     // constructor
     explicit Builder(const TrainParam& param,
-                     std::unique_ptr<SplitEvaluator> spliteval)
+                     std::unique_ptr<SplitEvaluator> spliteval,
+                     FeatureInteractionConstraintHost _interaction_constraints)
         : param_(param), nthread_(omp_get_max_threads()),
-          spliteval_(std::move(spliteval)) {}
+          spliteval_(std::move(spliteval)),
+          interaction_constraints_{std::move(_interaction_constraints)} {}
     // update one tree, growing
     virtual void Update(const std::vector<GradientPair>& gpair,
                         DMatrix* p_fmat,
@@ -116,6 +123,7 @@ class ColMaker: public TreeUpdater {
                              snode_[nid].best.SplitIndex(),
                              snode_[cleft].weight,
                              snode_[cright].weight);
+        interaction_constraints_.Split(nid, snode_[nid].best.SplitIndex(), cleft, cright);
       }
       qexpand_ = newnodes;
       // if nothing left to be expand, break
@@ -251,8 +259,9 @@ class ColMaker: public TreeUpdater {
                                   const std::vector<GradientPair> &gpair) {
       // TODO(tqchen): double check stats order.
       const bool ind = col.size() != 0 && col[0].fvalue == col[col.size() - 1].fvalue;
-      bool need_forward = param_.NeedForwardSearch(p_fmat->GetColDensity(fid), ind);
-      bool need_backward = param_.NeedBackwardSearch(p_fmat->GetColDensity(fid), ind);
+      auto col_density = p_fmat->GetColDensity(fid);
+      bool need_forward = param_.NeedForwardSearch(col_density, ind);
+      bool need_backward = param_.NeedBackwardSearch(col_density, ind);
       const std::vector<int> &qexpand = qexpand_;
 #pragma omp parallel
       {
@@ -391,7 +400,7 @@ class ColMaker: public TreeUpdater {
     // update enumeration solution
     inline void UpdateEnumeration(int nid, GradientPair gstats,
                                   bst_float fvalue, int d_step, bst_uint fid,
-                                  GradStats &c, std::vector<ThreadEntry> &temp) { // NOLINT(*)
+                                  GradStats &c, std::vector<ThreadEntry> &temp) const { // NOLINT(*)
       // get the statistics of nid
       ThreadEntry &e = temp[nid];
       // test if first hit, this is fine, because we set 0 during init
@@ -404,7 +413,7 @@ class ColMaker: public TreeUpdater {
             e.stats.sum_hess >= param_.min_child_weight) {
           c.SetSubstract(snode_[nid].stats, e.stats);
           if (c.sum_hess >= param_.min_child_weight) {
-            bst_float loss_chg;
+            bst_float loss_chg {0};
             if (d_step == -1) {
               loss_chg = static_cast<bst_float>(
                   spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
@@ -438,12 +447,13 @@ class ColMaker: public TreeUpdater {
         }
       }
     }
     // same as EnumerateSplit, with cacheline prefetch optimization
-    inline void EnumerateSplitCacheOpt(const Entry *begin,
-                                       const Entry *end,
-                                       int d_step,
-                                       bst_uint fid,
-                                       const std::vector<GradientPair> &gpair,
-                                       std::vector<ThreadEntry> &temp) { // NOLINT(*)
+    void EnumerateSplit(const Entry *begin,
+                        const Entry *end,
+                        int d_step,
+                        bst_uint fid,
+                        const std::vector<GradientPair> &gpair,
+                        std::vector<ThreadEntry> &temp) const { // NOLINT(*)
+      CHECK(param_.cache_opt) << "Support for `cache_opt' is removed in 1.0.0";
       const std::vector<int> &qexpand = qexpand_;
       // clear all the temp statistics
       for (auto nid : qexpand) {
@@ -474,12 +484,13 @@ class ColMaker: public TreeUpdater {
         }
         for (i = 0, p = it; i < kBuffer; ++i, p += d_step) {
           const int nid = buf_position[i];
-          if (nid < 0) continue;
+          if (nid < 0 || !interaction_constraints_.Query(nid, fid)) { continue; }
           this->UpdateEnumeration(nid, buf_gpair[i],
                                   p->fvalue, d_step,
                                   fid, c, temp);
         }
       }
       // finish up the ending piece
       for (it = align_end, i = 0; it != end; ++i, it += d_step) {
         buf_position[i] = position_[it->index];
@@ -487,7 +498,7 @@ class ColMaker: public TreeUpdater {
       }
       for (it = align_end, i = 0; it != end; ++i, it += d_step) {
         const int nid = buf_position[i];
-        if (nid < 0) continue;
+        if (nid < 0 || !interaction_constraints_.Query(nid, fid)) { continue; }
         this->UpdateEnumeration(nid, buf_gpair[i],
                                 it->fvalue, d_step,
                                 fid, c, temp);
@@ -518,136 +529,43 @@ class ColMaker: public TreeUpdater {
         }
       }
     }
-    // enumerate the split values of specific feature
-    inline void EnumerateSplit(const Entry *begin,
-                               const Entry *end,
-                               int d_step,
-                               bst_uint fid,
-                               const std::vector<GradientPair> &gpair,
-                               const MetaInfo &info,
-                               std::vector<ThreadEntry> &temp) { // NOLINT(*)
-      // use cacheline aware optimization
-      if (param_.cache_opt != 0) {
-        EnumerateSplitCacheOpt(begin, end, d_step, fid, gpair, temp);
-        return;
-      }
-      const std::vector<int> &qexpand = qexpand_;
-      // clear all the temp statistics
-      for (auto nid : qexpand) {
-        temp[nid].stats = GradStats();
-      }
-      // left statistics
-      GradStats c;
-      for (const Entry *it = begin; it != end; it += d_step) {
-        const bst_uint ridx = it->index;
-        const int nid = position_[ridx];
-        if (nid < 0) continue;
-        // start working
-        const bst_float fvalue = it->fvalue;
-        // get the statistics of nid
-        ThreadEntry &e = temp[nid];
-        // test if first hit, this is fine, because we set 0 during init
-        if (e.stats.Empty()) {
-          e.stats.Add(gpair[ridx]);
-          e.last_fvalue = fvalue;
-        } else {
-          // try to find a split
-          if (fvalue != e.last_fvalue &&
-              e.stats.sum_hess >= param_.min_child_weight) {
-            c.SetSubstract(snode_[nid].stats, e.stats);
-            if (c.sum_hess >= param_.min_child_weight) {
-              bst_float loss_chg;
-              if (d_step == -1) {
-                loss_chg = static_cast<bst_float>(
-                    spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
-                    snode_[nid].root_gain);
-                e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
-                              d_step == -1, c, e.stats);
-              } else {
-                loss_chg = static_cast<bst_float>(
-                    spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
-                    snode_[nid].root_gain);
-                e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
-                              d_step == -1, e.stats, c);
-              }
-            }
-          }
-          // update the statistics
-          e.stats.Add(gpair[ridx]);
-          e.last_fvalue = fvalue;
-        }
-      }
-      // finish updating all statistics, check if it is possible to include all sum statistics
-      for (int nid : qexpand) {
-        ThreadEntry &e = temp[nid];
-        c.SetSubstract(snode_[nid].stats, e.stats);
-        if (e.stats.sum_hess >= param_.min_child_weight &&
-            c.sum_hess >= param_.min_child_weight) {
-          bst_float loss_chg;
-          GradStats left_sum;
-          GradStats right_sum;
-          if (d_step == -1) {
-            left_sum = c;
-            right_sum = e.stats;
-          } else {
-            left_sum = e.stats;
-            right_sum = c;
-          }
-          loss_chg = static_cast<bst_float>(
-              spliteval_->ComputeSplitScore(nid, fid, left_sum, right_sum) -
-              snode_[nid].root_gain);
-          const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
-          const bst_float delta = d_step == +1 ? gap : -gap;
-          e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, left_sum, right_sum);
-        }
-      }
-    }
     // update the solution candidate
     virtual void UpdateSolution(const SparsePage &batch,
-                                const std::vector<int> &feat_set,
+                                const std::vector<bst_feature_t> &feat_set,
                                 const std::vector<GradientPair> &gpair,
                                 DMatrix*p_fmat) {
-      const MetaInfo& info = p_fmat->Info();
       // start enumeration
       const auto num_features = static_cast<bst_omp_uint>(feat_set.size());
 #if defined(_OPENMP)
       const int batch_size = // NOLINT
           std::max(static_cast<int>(num_features / this->nthread_ / 32), 1);
 #endif // defined(_OPENMP)
-      int poption = param_.parallel_option;
-      if (poption == 2) {
-        poption = static_cast<int>(num_features) * 2 < this->nthread_ ? 1 : 0;
-      }
-      if (poption == 0) {
+      CHECK_EQ(param_.parallel_option, 0) << "Support for `parallel_option' is removed in 1.0.0";
+      {
         std::vector<float> densities(num_features);
         CHECK_EQ(feat_set.size(), num_features);
         for (bst_omp_uint i = 0; i < num_features; ++i) {
-          int32_t const fid = feat_set[i];
+          bst_feature_t const fid = feat_set[i];
           densities.at(i) = p_fmat->GetColDensity(fid);
         }
 #pragma omp parallel for schedule(dynamic, batch_size)
         for (bst_omp_uint i = 0; i < num_features; ++i) {
-          int32_t const fid = feat_set[i];
+          bst_feature_t const fid = feat_set[i];
           int32_t const tid = omp_get_thread_num();
           auto c = batch[fid];
           const bool ind = c.size() != 0 && c[0].fvalue == c[c.size() - 1].fvalue;
           auto const density = densities[i];
           if (param_.NeedForwardSearch(density, ind)) {
             this->EnumerateSplit(c.data(), c.data() + c.size(), +1,
-                                 fid, gpair, info, stemp_[tid]);
+                                 fid, gpair, stemp_[tid]);
           }
           if (param_.NeedBackwardSearch(density, ind)) {
             this->EnumerateSplit(c.data() + c.size() - 1, c.data() - 1, -1,
-                                 fid, gpair, info, stemp_[tid]);
+                                 fid, gpair, stemp_[tid]);
           }
         }
-      } else {
-        for (bst_omp_uint fid = 0; fid < num_features; ++fid) {
-          this->ParallelFindSplit(batch[fid], fid,
-                                  p_fmat, gpair);
-        }
       }
     }
     // find splits at current level, do split per level
@@ -664,7 +582,7 @@ class ColMaker: public TreeUpdater {
       this->SyncBestSolution(qexpand);
       // get the best result, we can synchronize the solution
       for (int nid : qexpand) {
-        NodeEntry &e = snode_[nid];
+        NodeEntry const &e = snode_[nid];
         // now we know the solution in snode[nid], set split
         if (e.best.loss_chg > kRtEps) {
           bst_float left_leaf_weight =
@@ -786,6 +704,8 @@ class ColMaker: public TreeUpdater {
     std::vector<int> qexpand_;
     // Evaluates splits and computes optimal weights for a given split
     std::unique_ptr<SplitEvaluator> spliteval_;
+
+    FeatureInteractionConstraintHost interaction_constraints_;
   };
 };
@@ -810,7 +730,8 @@ class DistColMaker : public ColMaker {
     CHECK_EQ(trees.size(), 1U) << "DistColMaker: only support one tree at a time";
     Builder builder(
       param_,
-      std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()));
+      std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
+      interaction_constraints_);
     // build the tree
     builder.Update(gpair->ConstHostVector(), dmat, trees[0]);
     //// prune the tree, note that pruner will sync the tree
@@ -823,8 +744,9 @@ class DistColMaker : public ColMaker {
   class Builder : public ColMaker::Builder {
    public:
     explicit Builder(const TrainParam &param,
-                     std::unique_ptr<SplitEvaluator> spliteval)
-        : ColMaker::Builder(param, std::move(spliteval)) {}
+                     std::unique_ptr<SplitEvaluator> spliteval,
+                     FeatureInteractionConstraintHost _interaction_constraints)
+        : ColMaker::Builder(param, std::move(spliteval), std::move(_interaction_constraints)) {}
     inline void UpdatePosition(DMatrix* p_fmat, const RegTree &tree) {
       const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
 #pragma omp parallel for schedule(static)
@@ -931,6 +853,8 @@ class DistColMaker : public ColMaker {
   TrainParam param_;
   // Cloned for each builder instantiation
   std::unique_ptr<SplitEvaluator> spliteval_;
+
+  FeatureInteractionConstraintHost interaction_constraints_;
 };

 XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker")

View File

@@ -247,7 +247,7 @@ __device__ void EvaluateFeature(
 template <int BLOCK_THREADS, typename GradientSumT>
 __global__ void EvaluateSplitKernel(
     common::Span<const GradientSumT> node_histogram,  // histogram for gradients
-    common::Span<const int> feature_set,              // Selected features
+    common::Span<const bst_feature_t> feature_set,    // Selected features
     DeviceNodeStats node,
     xgboost::EllpackMatrix matrix,
     GPUTrainingParam gpu_param,
@@ -582,8 +582,8 @@ struct GPUHistMakerDevice {
       auto nidx = nidxs[i];
       auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx));
       p_feature_set->SetDevice(device_id);
-      auto d_sampled_features = p_feature_set->DeviceSpan();
-      common::Span<int32_t> d_feature_set =
+      common::Span<bst_feature_t> d_sampled_features = p_feature_set->DeviceSpan();
+      common::Span<bst_feature_t> d_feature_set =
           interaction_constraints.Query(d_sampled_features, nidx);
       auto d_split_candidates =
           d_split_candidates_all.subspan(i * num_columns, d_feature_set.size());

View File

@@ -1,18 +1,21 @@
 /*!
- * Copyright 2014 by Contributors
+ * Copyright 2014-2019 by Contributors
  * \file updater_histmaker.cc
  * \brief use histogram counting to construct a tree
  * \author Tianqi Chen
  */
 #include <rabit/rabit.h>
-#include <xgboost/base.h>
-#include <xgboost/tree_updater.h>
 #include <vector>
 #include <algorithm>

+#include "xgboost/tree_updater.h"
+#include "xgboost/base.h"
+#include "xgboost/logging.h"
 #include "../common/quantile.h"
 #include "../common/group_data.h"
 #include "./updater_basemaker-inl.h"
+#include "constraints.h"

 namespace xgboost {
 namespace tree {
@@ -24,12 +27,13 @@ class HistMaker: public BaseMaker {
   void Update(HostDeviceVector<GradientPair> *gpair,
               DMatrix *p_fmat,
               const std::vector<RegTree*> &trees) override {
+    interaction_constraints_.Configure(param_, p_fmat->Info().num_col_);
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
     param_.learning_rate = lr / trees.size();
     // build tree
     for (auto tree : trees) {
-      this->Update(gpair->ConstHostVector(), p_fmat, tree);
+      this->UpdateTree(gpair->ConstHostVector(), p_fmat, tree);
     }
     param_.learning_rate = lr;
   }
@@ -38,43 +42,39 @@ class HistMaker: public BaseMaker {
   }

  protected:
-  /*! \brief a single histogram */
+  /*! \brief a single column of histogram cuts */
   struct HistUnit {
     /*! \brief cutting point of histogram, contains maximum point */
-    const bst_float *cut;
+    const float *cut;
     /*! \brief content of statistics data */
     GradStats *data;
     /*! \brief size of histogram */
-    unsigned size;
+    uint32_t size;
     // default constructor
     HistUnit() = default;
     // constructor
-    HistUnit(const bst_float *cut, GradStats *data, unsigned size)
-        : cut(cut), data(data), size(size) {}
+    HistUnit(const float *cut, GradStats *data, uint32_t size)
+        : cut{cut}, data{data}, size{size} {}
     /*! \brief add a histogram to data */
-    inline void Add(bst_float fv,
-                    const std::vector<GradientPair> &gpair,
-                    const MetaInfo &info,
-                    const bst_uint ridx) {
-      unsigned i = std::upper_bound(cut, cut + size, fv) - cut;
+    void Add(float fv, const std::vector<GradientPair> &gpair,
+             const MetaInfo &info, const size_t ridx) {
+      unsigned bin = std::upper_bound(cut, cut + size, fv) - cut;
       CHECK_NE(size, 0U) << "try insert into size=0";
-      CHECK_LT(i, size);
-      data[i].Add(gpair[ridx]);
+      CHECK_LT(bin, size);
+      data[bin].Add(gpair[ridx]);
     }
   };
   /*! \brief a set of histograms from different index */
   struct HistSet {
     /*! \brief the index pointer of each histunit */
-    const unsigned *rptr;
+    const uint32_t *rptr;
     /*! \brief cutting points in each histunit */
     const bst_float *cut;
     /*! \brief data in different hist unit */
     std::vector<GradStats> data;
-    /*! \brief */
+    /*! \brief return a column of histogram cuts */
     inline HistUnit operator[](size_t fid) {
-      return {cut + rptr[fid],
-              &data[0] + rptr[fid],
-              rptr[fid+1] - rptr[fid]};
+      return {cut + rptr[fid], &data[0] + rptr[fid], rptr[fid+1] - rptr[fid]};
     }
   };
   // thread workspace
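The Add method above bins a feature value by running std::upper_bound over the column's cut points. A self-contained illustration of that lookup (the cut values below are made up for the example):

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  // Upper bounds of three histogram bins, mirroring HistUnit::Add above.
  std::vector<float> cut {0.5f, 1.5f, 2.5f};
  for (float fv : {0.2f, 0.7f, 2.4f}) {
    auto bin = std::upper_bound(cut.begin(), cut.end(), fv) - cut.begin();
    std::cout << fv << " -> bin " << bin << "\n";  // bins 0, 1, 2
  }
}

A value equal to a cut point lands in the following bin, hence the CHECK_LT(bin, size) guard in the real code.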
@@ -110,26 +110,27 @@ class HistMaker: public BaseMaker {
   // reducer for histogram
   rabit::Reducer<GradStats, GradStats::Reduce> histred_;
   // set of working features
-  std::vector<bst_uint> fwork_set_;
+  std::vector<bst_feature_t> selected_features_;
   // update function implementation
-  virtual void Update(const std::vector<GradientPair> &gpair,
-                      DMatrix *p_fmat,
-                      RegTree *p_tree) {
+  virtual void UpdateTree(const std::vector<GradientPair> &gpair,
+                          DMatrix *p_fmat,
+                          RegTree *p_tree) {
     CHECK(param_.max_depth > 0) << "max_depth must be larger than 0";
     this->InitData(gpair, *p_fmat, *p_tree);
-    this->InitWorkSet(p_fmat, *p_tree, &fwork_set_);
+    this->InitWorkSet(p_fmat, *p_tree, &selected_features_);
     // mark root node as fresh.
     for (int i = 0; i < p_tree->param.num_roots; ++i) {
       (*p_tree)[i].SetLeaf(0.0f, 0);
     }
+    CHECK_EQ(p_tree->param.num_roots, 1) << "Support for num roots is removed.";

     for (int depth = 0; depth < param_.max_depth; ++depth) {
       // reset and propose candidate split
-      this->ResetPosAndPropose(gpair, p_fmat, fwork_set_, *p_tree);
+      this->ResetPosAndPropose(gpair, p_fmat, selected_features_, *p_tree);
       // create histogram
-      this->CreateHist(gpair, p_fmat, fwork_set_, *p_tree);
+      this->CreateHist(gpair, p_fmat, selected_features_, *p_tree);
       // find split based on histogram statistics
-      this->FindSplit(depth, gpair, p_fmat, fwork_set_, p_tree);
+      this->FindSplit(depth, gpair, p_fmat, selected_features_, p_tree);
       // reset position after split
       this->ResetPositionAfterSplit(p_fmat, *p_tree);
       this->UpdateQueueExpand(*p_tree);
@@ -145,12 +146,12 @@ class HistMaker: public BaseMaker {
   // (2) propose a set of candidate cuts and set wspace.rptr wspace.cut correctly
   virtual void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
                                   DMatrix *p_fmat,
-                                  const std::vector <bst_uint> &fset,
+                                  const std::vector <bst_feature_t> &fset,
                                   const RegTree &tree) = 0;
   // initialize the current working set of features in this round
   virtual void InitWorkSet(DMatrix *p_fmat,
                            const RegTree &tree,
-                           std::vector<bst_uint> *p_fset) {
+                           std::vector<bst_feature_t> *p_fset) {
     p_fset->resize(tree.param.num_feature);
     for (size_t i = 0; i < p_fset->size(); ++i) {
       (*p_fset)[i] = static_cast<unsigned>(i);
@@ -162,15 +163,15 @@ class HistMaker: public BaseMaker {
   }
   virtual void CreateHist(const std::vector<GradientPair> &gpair,
                           DMatrix *p_fmat,
-                          const std::vector <bst_uint> &fset,
+                          const std::vector <bst_feature_t> &fset,
                           const RegTree &tree) = 0;

  private:
-  inline void EnumerateSplit(const HistUnit &hist,
-                             const GradStats &node_sum,
-                             bst_uint fid,
-                             SplitEntry *best,
-                             GradStats *left_sum) {
+  void EnumerateSplit(const HistUnit &hist,
+                      const GradStats &node_sum,
+                      bst_uint fid,
+                      SplitEntry *best,
+                      GradStats *left_sum) const {
     if (hist.size == 0) return;

     double root_gain = CalcGain(param_, node_sum.GetGrad(), node_sum.GetHess());
@@ -203,12 +204,13 @@ class HistMaker: public BaseMaker {
       }
     }
   }

-  inline void FindSplit(int depth,
-                        const std::vector<GradientPair> &gpair,
-                        DMatrix *p_fmat,
-                        const std::vector <bst_uint> &fset,
-                        RegTree *p_tree) {
-    const size_t num_feature = fset.size();
+  void FindSplit(int depth,
+                 const std::vector<GradientPair> &gpair,
+                 DMatrix *p_fmat,
+                 const std::vector <bst_feature_t> &feature_set,
+                 RegTree *p_tree) {
+    const size_t num_feature = feature_set.size();
     // get the best split condition for each node
     std::vector<SplitEntry> sol(qexpand_.size());
     std::vector<GradStats> left_sum(qexpand_.size());
@@ -219,15 +221,20 @@ class HistMaker: public BaseMaker {
       CHECK_EQ(node2workindex_[nid], static_cast<int>(wid));
       SplitEntry &best = sol[wid];
       GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
-      for (size_t i = 0; i < fset.size(); ++i) {
+      for (size_t i = 0; i < feature_set.size(); ++i) {
+        // Query is thread safe as it's a const function.
+        if (!this->interaction_constraints_.Query(nid, feature_set[i])) {
+          continue;
+        }
         EnumerateSplit(this->wspace_.hset[0][i + wid * (num_feature+1)],
-                       node_sum, fset[i], &best, &left_sum[wid]);
+                       node_sum, feature_set[i], &best, &left_sum[wid]);
       }
     }
     // get the best result, we can synchronize the solution
     for (bst_omp_uint wid = 0; wid < nexpand; ++wid) {
-      const int nid = qexpand_[wid];
-      const SplitEntry &best = sol[wid];
+      const bst_node_t nid = qexpand_[wid];
+      SplitEntry const& best = sol[wid];
       const GradStats &node_sum = wspace_.hset[0][num_feature + wid * (num_feature + 1)].data[0];
       this->SetStats(p_tree, nid, node_sum);
       // set up the values
@@ -246,11 +253,13 @@ class HistMaker: public BaseMaker {
                           best.DefaultLeft(), base_weight, left_leaf_weight,
                           right_leaf_weight, best.loss_chg,
                           node_sum.sum_hess);
+        // right side sum
         GradStats right_sum;
         right_sum.SetSubstract(node_sum, left_sum[wid]);
-        this->SetStats(p_tree, (*p_tree)[nid].LeftChild(), left_sum[wid]);
-        this->SetStats(p_tree, (*p_tree)[nid].RightChild(), right_sum);
+        auto left_child = (*p_tree)[nid].LeftChild();
+        auto right_child = (*p_tree)[nid].RightChild();
+        this->SetStats(p_tree, left_child, left_sum[wid]);
+        this->SetStats(p_tree, right_child, right_sum);
+        this->interaction_constraints_.Split(nid, best.SplitIndex(), left_child, right_child);
       } else {
         (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
       }
@@ -314,7 +323,7 @@ class CQHistMaker: public HistMaker {
   // initialize the work set of tree
   void InitWorkSet(DMatrix *p_fmat,
                    const RegTree &tree,
-                   std::vector<bst_uint> *p_fset) override {
+                   std::vector<bst_feature_t> *p_fset) override {
     if (p_fmat != cache_dmatrix_) {
       feat_helper_.InitByCol(p_fmat, tree);
       cache_dmatrix_ = p_fmat;
@@ -325,7 +334,7 @@ class CQHistMaker: public HistMaker {
   // code to create histogram
   void CreateHist(const std::vector<GradientPair> &gpair,
                   DMatrix *p_fmat,
-                  const std::vector<bst_uint> &fset,
+                  const std::vector<bst_feature_t> &fset,
                   const RegTree &tree) override {
     const MetaInfo &info = p_fmat->Info();
     // fill in reverse map
@@ -365,7 +374,6 @@ class CQHistMaker: public HistMaker {
       }
     };
     // sync the histogram
-    // if it is C++11, use lazy evaluation for Allreduce
     this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
                              this->wspace_.hset[0].data.size(), lazy_get_hist);
   }
@@ -376,7 +384,7 @@ class CQHistMaker: public HistMaker {
   }
   void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
                           DMatrix *p_fmat,
-                          const std::vector<bst_uint> &fset,
+                          const std::vector<bst_feature_t> &fset,
                           const RegTree &tree) override {
     const MetaInfo &info = p_fmat->Info();
     // fill in reverse map
@@ -485,7 +493,7 @@ class CQHistMaker: public HistMaker {
                                const SparsePage::Inst &col,
                                const MetaInfo &info,
                                const RegTree &tree,
-                               const std::vector<bst_uint> &fset,
+                               const std::vector<bst_feature_t> &fset,
                                bst_uint fid_offset,
                                std::vector<HistEntry> *p_temp) {
     if (col.size() == 0) return;
@@ -612,7 +620,7 @@ class CQHistMaker: public HistMaker {
   // temp space to map feature id to working index
   std::vector<int> feat2workindex_;
   // set of index from fset that are current work set
-  std::vector<bst_uint> work_set_;
+  std::vector<bst_feature_t> work_set_;
   // set of index from that are split candidates.
   std::vector<bst_uint> fsplit_set_;
   // thread temp data
@@ -641,7 +649,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
  protected:
   void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
                           DMatrix *p_fmat,
-                          const std::vector<bst_uint> &fset,
+                          const std::vector<bst_feature_t> &fset,
                           const RegTree &tree) override {
     if (this->qexpand_.size() == 1) {
       cached_rptr_.clear();
@@ -672,7 +680,7 @@ class GlobalProposalHistMaker: public CQHistMaker {
   // code to create histogram
   void CreateHist(const std::vector<GradientPair> &gpair,
                   DMatrix *p_fmat,
-                  const std::vector<bst_uint> &fset,
+                  const std::vector<bst_feature_t> &fset,
                   const RegTree &tree) override {
     const MetaInfo &info = p_fmat->Info();
     // fill in reverse map
@@ -692,7 +700,8 @@ class GlobalProposalHistMaker: public CQHistMaker {
       this->SetDefaultPostion(p_fmat, tree);
       this->work_set_.insert(this->work_set_.end(), this->fsplit_set_.begin(),
                              this->fsplit_set_.end());
-      std::sort(this->work_set_.begin(), this->work_set_.end());
+      XGBOOST_PARALLEL_SORT(this->work_set_.begin(), this->work_set_.end(),
+                            std::less<decltype(this->work_set_)::value_type>{});
       this->work_set_.resize(
           std::unique(this->work_set_.begin(), this->work_set_.end()) - this->work_set_.begin());
@@ -740,6 +749,7 @@ XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker")
     return new CQHistMaker();
   });

+// The updater for approx tree method.
 XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker")
 .describe("Tree constructor that uses approximate global of histogram construction.")
 .set_body([]() {

View File

@@ -22,6 +22,7 @@
 #include "./param.h"
 #include "./updater_quantile_hist.h"
 #include "./split_evaluator.h"
+#include "constraints.h"
 #include "../common/random.h"
 #include "../common/hist_util.h"
 #include "../common/row_set.h"
@@ -65,12 +66,14 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
   // rescale learning rate according to size of trees
   float lr = param_.learning_rate;
   param_.learning_rate = lr / trees.size();
+  int_constraint_.Configure(param_, dmat->Info().num_col_);
   // build tree
   if (!builder_) {
     builder_.reset(new Builder(
         param_,
         std::move(pruner_),
-        std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone())));
+        std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
+        int_constraint_));
   }
   for (auto tree : trees) {
     builder_->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, tree);
@@ -170,6 +173,8 @@ void QuantileHistMaker::Builder::BuildNodeStats(
       auto parent_split_feature_id = snode_[parent_id].best.SplitIndex();
       spliteval_->AddSplit(parent_id, left_sibling_id, nid, parent_split_feature_id,
                            snode_[left_sibling_id].weight, snode_[nid].weight);
+      interaction_constraints_.Split(parent_id, parent_split_feature_id,
+                                     left_sibling_id, nid);
     }
   }
   builder_monitor_.Stop("BuildNodeStats");
@@ -298,6 +303,7 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
       bst_uint featureid = snode_[nid].best.SplitIndex();
       spliteval_->AddSplit(nid, cleft, cright, featureid,
                            snode_[cleft].weight, snode_[cright].weight);
+      interaction_constraints_.Split(nid, featureid, cleft, cright);

       this->EvaluateSplit(cleft, gmat, hist_, *p_fmat, *p_tree);
       this->EvaluateSplit(cright, gmat, hist_, *p_fmat, *p_tree);
@@ -325,6 +331,7 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
   const std::vector<GradientPair>& gpair_h = gpair->ConstHostVector();

   spliteval_->Reset();
+  interaction_constraints_.Reset();

   this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
@@ -457,7 +464,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
     }

     bool has_neg_hess = false;
-    for (size_t tid = 0; tid < this->nthread_; ++tid) {
+    for (int32_t tid = 0; tid < this->nthread_; ++tid) {
       if (p_buff[tid]) {
         has_neg_hess = true;
       }
@@ -561,8 +568,8 @@ void QuantileHistMaker::Builder::EvaluateSplit(const int nid,
   // start enumeration
   const MetaInfo& info = fmat.Info();
   auto p_feature_set = column_sampler_.GetFeatureSet(tree.GetDepth(nid));
-  const auto& feature_set = p_feature_set->HostVector();
-  const auto nfeature = static_cast<bst_uint>(feature_set.size());
+  auto const& feature_set = p_feature_set->HostVector();
+  const auto nfeature = static_cast<bst_feature_t>(feature_set.size());
   const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
   best_split_tloc_.resize(nthread);
 #pragma omp parallel for schedule(static) num_threads(nthread)
@@ -576,9 +583,7 @@ void QuantileHistMaker::Builder::EvaluateSplit(const int nid,
     const auto feature_id = static_cast<bst_uint>(feature_set[i]);
     const auto tid = static_cast<unsigned>(omp_get_thread_num());
     const auto node_id = static_cast<bst_uint>(nid);
-    // Narrow search space by dropping features that are not feasible under the
-    // given set of constraints (e.g. feature interaction constraints)
-    if (spliteval_->CheckFeatureConstraint(node_id, feature_id)) {
+    if (interaction_constraints_.Query(node_id, feature_id)) {
       this->EnumerateSplit(-1, gmat, node_hist, snode_[nid], info,
                            &best_split_tloc_[tid], feature_id, node_id);
       this->EnumerateSplit(+1, gmat, node_hist, snode_[nid], info,

View File

@@ -19,6 +19,7 @@
 #include <unordered_map>
 #include <utility>

+#include "constraints.h"
 #include "./param.h"
 #include "./split_evaluator.h"
 #include "../common/random.h"
@@ -123,10 +124,11 @@ class QuantileHistMaker: public TreeUpdater {
     // constructor
     explicit Builder(const TrainParam& param,
                      std::unique_ptr<TreeUpdater> pruner,
-                     std::unique_ptr<SplitEvaluator> spliteval)
+                     std::unique_ptr<SplitEvaluator> spliteval,
+                     FeatureInteractionConstraintHost int_constraints_)
       : param_(param), pruner_(std::move(pruner)),
-        spliteval_(std::move(spliteval)), p_last_tree_(nullptr),
-        p_last_fmat_(nullptr) {
+        spliteval_(std::move(spliteval)), interaction_constraints_{int_constraints_},
+        p_last_tree_(nullptr), p_last_fmat_(nullptr) {
       builder_monitor_.Init("Quantile::Builder");
     }
     // update one tree, growing
@@ -296,6 +298,7 @@ class QuantileHistMaker: public TreeUpdater {
     GHistBuilder hist_builder_;
     std::unique_ptr<TreeUpdater> pruner_;
     std::unique_ptr<SplitEvaluator> spliteval_;
+    FeatureInteractionConstraintHost interaction_constraints_;

     // back pointers to tree and data matrix
     const RegTree* p_last_tree_;
@@ -321,6 +324,7 @@ class QuantileHistMaker: public TreeUpdater {
   std::unique_ptr<Builder> builder_;
   std::unique_ptr<TreeUpdater> pruner_;
   std::unique_ptr<SplitEvaluator> spliteval_;
+  FeatureInteractionConstraintHost int_constraint_;
 };

 }  // namespace tree

View File

@@ -56,7 +56,7 @@ void TestDeviceSketch(bool use_external_memory) {
   size_t row_stride = DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &hmat_gpu);

   // compare the row stride with the one obtained from the dmatrix
-  size_t expected_row_stride = 0;
+  bst_row_t expected_row_stride = 0;
   for (const auto &batch : dmat->get()->GetBatches<xgboost::SparsePage>()) {
     const auto &offset_vec = batch.offset.ConstHostVector();
     for (int i = 1; i <= offset_vec.size() -1; ++i) {

View File

@@ -55,7 +55,7 @@ TEST(ColumnSampler, ThreadSynchronisation) {
   int n = 128;
   size_t iterations = 10;
   size_t levels = 5;
-  std::vector<int> reference_result;
+  std::vector<bst_feature_t> reference_result;
   bool success =
       true;  // Cannot use google test asserts in multithreaded region
 #pragma omp parallel num_threads(num_threads)

View File

@@ -9,7 +9,7 @@
 namespace xgboost {

 TEST(SparsePage, PushCSC) {
-  std::vector<size_t> offset {0};
+  std::vector<bst_row_t> offset {0};
   std::vector<Entry> data;
   SparsePage page;
   page.offset.HostVector() = offset;

View File

@@ -99,7 +99,7 @@ TEST(MetaInfo, LoadQid) {
   const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};
   CHECK(info.group_ptr_ == expected_group_ptr);

-  const std::vector<size_t> expected_offset{
+  const std::vector<xgboost::bst_row_t> expected_offset{
     0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60
   };
   const std::vector<xgboost::Entry> expected_data{

View File

@@ -249,7 +249,7 @@ inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
                  0.26f, 0.71f, 1.83f});
   cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});

-  size_t row_stride = 0;
+  bst_row_t row_stride = 0;
   const auto &offset_vec = batch.offset.ConstHostVector();
   for (size_t i = 1; i < offset_vec.size(); ++i) {
     row_stride = std::max(row_stride, offset_vec[i] - offset_vec[i-1]);

View File

@@ -8,6 +8,7 @@
 int main(int argc, char ** argv) {
   xgboost::Args args {{"verbosity", "2"}};
   xgboost::ConsoleLogger::Configure(args);
+
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";
   return RUN_ALL_TESTS();

View File

@@ -91,7 +91,7 @@ void TestUpdatePosition() {
   EXPECT_EQ(rp.GetRows(3).size(), 2);
   EXPECT_EQ(rp.GetRows(4).size(), 3);
   // Check position is as expected
-  EXPECT_EQ(rp.GetPositionHost(), std::vector<RowPartitioner::TreePositionT>({3,3,4,4,4,2,2,2,2,2}));
+  EXPECT_EQ(rp.GetPositionHost(), std::vector<bst_node_t>({3,3,4,4,4,2,2,2,2,2}));
 }

 TEST(RowPartitioner, Basic) { TestUpdatePosition(); }

View File

@@ -0,0 +1,60 @@
#include <gtest/gtest.h>
#include <xgboost/base.h>
#include <xgboost/logging.h>

#include <memory>
#include <string>

#include "../../../src/tree/constraints.h"

namespace xgboost {
namespace tree {

TEST(CPUFeatureInteractionConstraint, Empty) {
  TrainParam param;
  param.UpdateAllowUnknown(Args{});
  bst_feature_t constexpr kFeatures = 6;
  FeatureInteractionConstraintHost constraints;
  constraints.Configure(param, kFeatures);

  // no-op
  constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2);

  std::vector<bst_feature_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
  common::Span<bst_feature_t> s_input_feature_list =
      common::Span<bst_feature_t>{h_input_feature_list};
  for (auto f : h_input_feature_list) {
    constraints.Query(f, 1);
  }
  // no-op
  ASSERT_TRUE(constraints.Query(94389, 12309));
}

TEST(CPUFeatureInteractionConstraint, Basic) {
  std::string const constraints_str = R"constraint([[1, 2], [2, 3, 4]])constraint";
  std::vector<std::pair<std::string, std::string>> args{
      {"interaction_constraints", constraints_str}};
  TrainParam param;
  param.interaction_constraints = constraints_str;

  bst_feature_t constexpr kFeatures = 6;
  FeatureInteractionConstraintHost constraints;
  constraints.Configure(param, kFeatures);

  constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2);
  std::vector<bst_feature_t> h_input_feature_list{0, 1, 2, 3, 4, 5};

  ASSERT_TRUE(constraints.Query(1, 1));
  ASSERT_TRUE(constraints.Query(1, 2));
  ASSERT_TRUE(constraints.Query(1, 3));
  ASSERT_TRUE(constraints.Query(1, 4));

  ASSERT_FALSE(constraints.Query(1, 0));
  ASSERT_FALSE(constraints.Query(1, 5));
}

}  // namespace tree
}  // namespace xgboost

View File

@@ -19,13 +19,13 @@ struct FConstraintWrapper : public FeatureInteractionConstraint {
   common::Span<LBitField64> GetNodeConstraints() {
     return FeatureInteractionConstraint::s_node_constraints_;
   }

-  FConstraintWrapper(tree::TrainParam param, int32_t n_features) :
+  FConstraintWrapper(tree::TrainParam param, bst_feature_t n_features) :
       FeatureInteractionConstraint(param, n_features) {}

-  dh::device_vector<int32_t> const& GetDSets() const {
+  dh::device_vector<bst_feature_t> const& GetDSets() const {
     return d_sets_;
   }
-  dh::device_vector<int32_t> const& GetDSetsPtr() const {
+  dh::device_vector<size_t> const& GetDSetsPtr() const {
     return d_sets_ptr_;
   }
 };
@@ -65,7 +65,7 @@ void CompareBitField(LBitField64 d_field, std::set<uint32_t> positions) {
 }  // anonymous namespace

-TEST(FeatureInteractionConstraint, Init) {
+TEST(GPUFeatureInteractionConstraint, Init) {
   {
     int32_t constexpr kFeatures = 6;
     tree::TrainParam param = GetParameter();
@@ -123,7 +123,7 @@ TEST(GPUFeatureInteractionConstraint, Init) {
   }
 }

-TEST(FeatureInteractionConstraint, Split) {
+TEST(GPUFeatureInteractionConstraint, Split) {
   tree::TrainParam param = GetParameter();
   int32_t constexpr kFeatures = 6;
   FConstraintWrapper constraints(param, kFeatures);
@@ -152,9 +152,9 @@ TEST(GPUFeatureInteractionConstraint, Split) {
   }
 }

-TEST(FeatureInteractionConstraint, QueryNode) {
+TEST(GPUFeatureInteractionConstraint, QueryNode) {
   tree::TrainParam param = GetParameter();
-  int32_t constexpr kFeatures = 6;
+  bst_feature_t constexpr kFeatures = 6;
   FConstraintWrapper constraints(param, kFeatures);

   {
@@ -165,9 +165,9 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
   {
     constraints.Split(/*node_id=*/ 0, /*feature_id=*/ 1, 1, 2);
     auto span = constraints.QueryNode(0);
-    std::vector<int32_t> h_result (span.size());
-    thrust::copy(thrust::device_ptr<int32_t>(span.data()),
-                 thrust::device_ptr<int32_t>(span.data() + span.size()),
+    std::vector<bst_feature_t> h_result (span.size());
+    thrust::copy(thrust::device_ptr<bst_feature_t>(span.data()),
+                 thrust::device_ptr<bst_feature_t>(span.data() + span.size()),
                  h_result.begin());
     ASSERT_EQ(h_result.size(), 2);
     ASSERT_EQ(h_result[0], 1);
@@ -177,9 +177,9 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
   {
     constraints.Split(1, /*feature_id=*/0, 3, 4);
     auto span = constraints.QueryNode(1);
-    std::vector<int32_t> h_result (span.size());
-    thrust::copy(thrust::device_ptr<int32_t>(span.data()),
-                 thrust::device_ptr<int32_t>(span.data() + span.size()),
+    std::vector<bst_feature_t> h_result (span.size());
+    thrust::copy(thrust::device_ptr<bst_feature_t>(span.data()),
+                 thrust::device_ptr<bst_feature_t>(span.data() + span.size()),
                  h_result.begin());
     ASSERT_EQ(h_result.size(), 3);
     ASSERT_EQ(h_result[0], 0);
@@ -189,8 +189,8 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
     // same as parent
     span = constraints.QueryNode(3);
     h_result.resize(span.size());
-    thrust::copy(thrust::device_ptr<int32_t>(span.data()),
-                 thrust::device_ptr<int32_t>(span.data() + span.size()),
+    thrust::copy(thrust::device_ptr<bst_feature_t>(span.data()),
+                 thrust::device_ptr<bst_feature_t>(span.data() + span.size()),
                  h_result.begin());
     ASSERT_EQ(h_result.size(), 3);
     ASSERT_EQ(h_result[0], 0);
@@ -204,9 +204,9 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
     FConstraintWrapper large_features(large_param, 256);
     large_features.Split(0, 139, 1, 2);
     auto span = large_features.QueryNode(0);
-    std::vector<int32_t> h_result (span.size());
-    thrust::copy(thrust::device_ptr<int32_t>(span.data()),
-                 thrust::device_ptr<int32_t>(span.data() + span.size()),
+    std::vector<bst_feature_t> h_result (span.size());
+    thrust::copy(thrust::device_ptr<bst_feature_t>(span.data()),
+                 thrust::device_ptr<bst_feature_t>(span.data() + span.size()),
                  h_result.begin());
     ASSERT_EQ(h_result.size(), 3);
     ASSERT_EQ(h_result[0], 1);
@@ -217,10 +217,10 @@ TEST(GPUFeatureInteractionConstraint, QueryNode) {
 namespace {

-void CompareFeatureList(common::Span<int32_t> s_output, std::vector<int32_t> solution) {
-  std::vector<int32_t> h_output(s_output.size());
-  thrust::copy(thrust::device_ptr<int32_t>(s_output.data()),
-               thrust::device_ptr<int32_t>(s_output.data() + s_output.size()),
+void CompareFeatureList(common::Span<bst_feature_t> s_output, std::vector<bst_feature_t> solution) {
+  std::vector<bst_feature_t> h_output(s_output.size());
+  thrust::copy(thrust::device_ptr<bst_feature_t>(s_output.data()),
+               thrust::device_ptr<bst_feature_t>(s_output.data() + s_output.size()),
                h_output.begin());
   ASSERT_EQ(h_output.size(), solution.size());
   for (size_t i = 0; i < solution.size(); ++i) {
@@ -230,21 +230,21 @@ void CompareFeatureList(common::Span<bst_feature_t> s_output, std::vector<bst_feature_t> sol
 }  // anonymous namespace

-TEST(FeatureInteractionConstraint, Query) {
+TEST(GPUFeatureInteractionConstraint, Query) {
   {
     tree::TrainParam param = GetParameter();
-    int32_t constexpr kFeatures = 6;
+    bst_feature_t constexpr kFeatures = 6;
     FConstraintWrapper constraints(param, kFeatures);
-    std::vector<int32_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
-    dh::device_vector<int32_t> d_input_feature_list (h_input_feature_list);
-    common::Span<int32_t> s_input_feature_list = dh::ToSpan(d_input_feature_list);
+    std::vector<bst_feature_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
+    dh::device_vector<bst_feature_t> d_input_feature_list (h_input_feature_list);
+    common::Span<bst_feature_t> s_input_feature_list = dh::ToSpan(d_input_feature_list);
     auto s_output = constraints.Query(s_input_feature_list, 0);
     CompareFeatureList(s_output, h_input_feature_list);
   }
   {
     tree::TrainParam param = GetParameter();
-    int32_t constexpr kFeatures = 6;
+    bst_feature_t constexpr kFeatures = 6;
     FConstraintWrapper constraints(param, kFeatures);
     constraints.Split(/*node_id=*/0, /*feature_id=*/1, /*left_id=*/1, /*right_id=*/2);
     constraints.Split(/*node_id=*/1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4);
@ -264,9 +264,9 @@ TEST(FeatureInteractionConstraint, Query) {
* *
*/ */
std::vector<int32_t> h_input_feature_list {0, 1, 2, 3, 4, 5}; std::vector<bst_feature_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
dh::device_vector<int32_t> d_input_feature_list (h_input_feature_list); dh::device_vector<bst_feature_t> d_input_feature_list (h_input_feature_list);
common::Span<int32_t> s_input_feature_list = dh::ToSpan(d_input_feature_list); common::Span<bst_feature_t> s_input_feature_list = dh::ToSpan(d_input_feature_list);
auto s_output = constraints.Query(s_input_feature_list, 1); auto s_output = constraints.Query(s_input_feature_list, 1);
CompareFeatureList(s_output, {0, 1, 2}); CompareFeatureList(s_output, {0, 1, 2});
@ -285,16 +285,16 @@ TEST(FeatureInteractionConstraint, Query) {
// Test shared feature // Test shared feature
{ {
tree::TrainParam param = GetParameter(); tree::TrainParam param = GetParameter();
int32_t constexpr kFeatures = 6; bst_feature_t constexpr kFeatures = 6;
std::string const constraints_str = R"constraint([[1, 2], [2, 3, 4]])constraint"; std::string const constraints_str = R"constraint([[1, 2], [2, 3, 4]])constraint";
param.interaction_constraints = constraints_str; param.interaction_constraints = constraints_str;
FConstraintWrapper constraints(param, kFeatures); FConstraintWrapper constraints(param, kFeatures);
constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2); constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2);
std::vector<int32_t> h_input_feature_list {0, 1, 2, 3, 4, 5}; std::vector<bst_feature_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
dh::device_vector<int32_t> d_input_feature_list (h_input_feature_list); dh::device_vector<bst_feature_t> d_input_feature_list (h_input_feature_list);
common::Span<int32_t> s_input_feature_list = dh::ToSpan(d_input_feature_list); common::Span<bst_feature_t> s_input_feature_list = dh::ToSpan(d_input_feature_list);
auto s_output = constraints.Query(s_input_feature_list, 1); auto s_output = constraints.Query(s_input_feature_list, 1);
CompareFeatureList(s_output, {1, 2, 3, 4}); CompareFeatureList(s_output, {1, 2, 3, 4});
@ -303,13 +303,13 @@ TEST(FeatureInteractionConstraint, Query) {
// Test choosing free feature in root // Test choosing free feature in root
{ {
tree::TrainParam param = GetParameter(); tree::TrainParam param = GetParameter();
int32_t constexpr kFeatures = 6; bst_feature_t constexpr kFeatures = 6;
std::string const constraints_str = R"constraint([[0, 1]])constraint"; std::string const constraints_str = R"constraint([[0, 1]])constraint";
param.interaction_constraints = constraints_str; param.interaction_constraints = constraints_str;
FConstraintWrapper constraints(param, kFeatures); FConstraintWrapper constraints(param, kFeatures);
std::vector<int32_t> h_input_feature_list {0, 1, 2, 3, 4, 5}; std::vector<bst_feature_t> h_input_feature_list {0, 1, 2, 3, 4, 5};
dh::device_vector<int32_t> d_input_feature_list (h_input_feature_list); dh::device_vector<bst_feature_t> d_input_feature_list (h_input_feature_list);
common::Span<int32_t> s_input_feature_list = dh::ToSpan(d_input_feature_list); common::Span<bst_feature_t> s_input_feature_list = dh::ToSpan(d_input_feature_list);
constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2); constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2);
auto s_output = constraints.Query(s_input_feature_list, 1); auto s_output = constraints.Query(s_input_feature_list, 1);
CompareFeatureList(s_output, {2}); CompareFeatureList(s_output, {2});
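The Query tests above all exercise one rule: a feature is usable at a node if it was already used on the path from the root, or if some constraint set contains both the candidate feature and every path feature. A rough pure-Python model of that semantics (allowed_features is a hypothetical helper written here for illustration, not XGBoost API):

# Hypothetical model of the Query() semantics exercised by the tests above.
# `constraint_sets` mirrors the interaction_constraints JSON lists.
def allowed_features(constraint_sets, path_features, candidates):
    if not path_features:                # root: every feature is allowed
        return sorted(candidates)
    allowed = set(path_features)         # path features may always be reused
    for f in candidates:
        if any(set(path_features) | {f} <= set(s) for s in constraint_sets):
            allowed.add(f)
    return sorted(allowed)

# Mirrors the "shared feature" case: root split on feature 2.
assert allowed_features([[1, 2], [2, 3, 4]], [2], range(6)) == [1, 2, 3, 4]
# Mirrors the "free feature in root" case: feature 2 is in no set.
assert allowed_features([[0, 1]], [2], range(6)) == [2]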

@ -0,0 +1,69 @@
#include <gtest/gtest.h>
#include <xgboost/tree_model.h>
#include <xgboost/tree_updater.h>
#include "../helpers.h"
namespace xgboost {
namespace tree {
TEST(GrowHistMaker, InteractionConstraint) {
size_t constexpr kRows = 32;
size_t constexpr kCols = 16;
GenericParameter param;
param.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
auto pp_dmat = CreateDMatrix(kRows, kCols, 0.6, 3);
auto p_dmat = *pp_dmat;
HostDeviceVector<GradientPair> gradients (kRows);
std::vector<GradientPair>& h_gradients = gradients.HostVector();
xgboost::SimpleLCG gen;
xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
for (size_t i = 0; i < kRows; ++i) {
bst_float grad = dist(&gen);
bst_float hess = dist(&gen);
h_gradients[i] = GradientPair(grad, hess);
}
{
// With constraints
RegTree tree;
tree.param.num_feature = kCols;
std::unique_ptr<TreeUpdater> updater { TreeUpdater::Create("grow_histmaker", &param) };
updater->Configure(Args{
{"interaction_constraints", "[[0, 1]]"},
{"num_feature", std::to_string(kCols)}});
updater->Update(&gradients, p_dmat.get(), {&tree});
ASSERT_EQ(tree.NumExtraNodes(), 4);
ASSERT_EQ(tree[0].SplitIndex(), 1);
ASSERT_EQ(tree[tree[0].LeftChild()].SplitIndex(), 0);
ASSERT_EQ(tree[tree[0].RightChild()].SplitIndex(), 0);
}
{
// Without constraints
RegTree tree;
tree.param.num_feature = kCols;
std::unique_ptr<TreeUpdater> updater { TreeUpdater::Create("grow_histmaker", &param) };
updater->Configure(Args{{"num_feature", std::to_string(kCols)}});
updater->Update(&gradients, p_dmat.get(), {&tree});
ASSERT_EQ(tree.NumExtraNodes(), 10);
ASSERT_EQ(tree[0].SplitIndex(), 1);
ASSERT_NE(tree[tree[0].LeftChild()].SplitIndex(), 0);
ASSERT_NE(tree[tree[0].RightChild()].SplitIndex(), 0);
}
delete pp_dmat;
}
} // namespace tree
} // namespace xgboost
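The tree shape the C++ test asserts on is also observable from the Python API. A minimal usage sketch (random data and round counts chosen here purely for illustration): with interaction_constraints='[[0, 1]]', any split below a split on feature 0 or 1 may only use features from {0, 1}.

# Minimal usage sketch; per the test additions below, the 'approx' tree
# method now honours interaction constraints as well.
import numpy as np
import xgboost

X = np.random.randn(128, 16)
y = np.random.randn(128)
dtrain = xgboost.DMatrix(X, label=y)
bst = xgboost.train({'tree_method': 'approx',
                     'interaction_constraints': '[[0, 1]]'},
                    dtrain, num_boost_round=4)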

@ -25,8 +25,9 @@ class QuantileHistMock : public QuantileHistMaker {
BuilderMock(const TrainParam& param, BuilderMock(const TrainParam& param,
std::unique_ptr<TreeUpdater> pruner, std::unique_ptr<TreeUpdater> pruner,
std::unique_ptr<SplitEvaluator> spliteval) std::unique_ptr<SplitEvaluator> spliteval,
: RealImpl(param, std::move(pruner), std::move(spliteval)) {} FeatureInteractionConstraintHost int_constraint)
: RealImpl(param, std::move(pruner), std::move(spliteval), std::move(int_constraint)) {}
public: public:
void TestInitData(const GHistIndexMatrix& gmat, void TestInitData(const GHistIndexMatrix& gmat,
@ -238,7 +239,8 @@ class QuantileHistMock : public QuantileHistMaker {
new BuilderMock( new BuilderMock(
param_, param_,
std::move(pruner_), std::move(pruner_),
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()))); std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
int_constraint_));
dmat_ = CreateDMatrix(kNRows, kNCols, 0.8, 3); dmat_ = CreateDMatrix(kNRows, kNCols, 0.8, 3);
} }
~QuantileHistMock() override { delete dmat_; } ~QuantileHistMock() override { delete dmat_; }

@ -1,57 +0,0 @@
#include <gtest/gtest.h>
#include <xgboost/logging.h>
#include <memory>
#include "../../../src/tree/split_evaluator.h"
namespace xgboost {
namespace tree {
TEST(SplitEvaluator, Interaction) {
std::string constraints_str = R"interaction([[0, 1], [1, 2, 3]])interaction";
std::vector<std::pair<std::string, std::string>> args{
{"interaction_constraints", constraints_str},
{"num_feature", "8"}};
{
std::unique_ptr<SplitEvaluator> eval{
SplitEvaluator::Create("elastic_net,interaction")};
eval->Init(args);
eval->AddSplit(0, 1, 2, /*feature_id=*/4, 0, 0);
eval->AddSplit(2, 3, 4, /*feature_id=*/5, 0, 0);
ASSERT_FALSE(eval->CheckFeatureConstraint(2, /*feature_id=*/0));
ASSERT_FALSE(eval->CheckFeatureConstraint(2, /*feature_id=*/1));
ASSERT_TRUE(eval->CheckFeatureConstraint(2, /*feature_id=*/4));
ASSERT_FALSE(eval->CheckFeatureConstraint(2, /*feature_id=*/5));
std::vector<int32_t> accepted_features; // for node 3
for (int32_t f = 0; f < 8; ++f) {
if (eval->CheckFeatureConstraint(3, f)) {
accepted_features.emplace_back(f);
}
}
std::vector<int32_t> solutions{4, 5};
ASSERT_EQ(accepted_features.size(), solutions.size());
for (size_t f = 0; f < accepted_features.size(); ++f) {
ASSERT_EQ(accepted_features[f], solutions[f]);
}
}
{
std::unique_ptr<SplitEvaluator> eval{
SplitEvaluator::Create("elastic_net,interaction")};
eval->Init(args);
eval->AddSplit(/*node_id=*/0, /*left_id=*/1, /*right_id=*/2, /*feature_id=*/4, 0, 0);
std::vector<int32_t> accepted_features; // for node 1
for (int32_t f = 0; f < 8; ++f) {
if (eval->CheckFeatureConstraint(1, f)) {
accepted_features.emplace_back(f);
}
}
ASSERT_EQ(accepted_features.size(), 1);
ASSERT_EQ(accepted_features[0], 4);
}
}
} // namespace tree
} // namespace xgboost
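The deleted SplitEvaluator test's expectations still hold under the standalone constraint handling, and can be re-checked against the hypothetical allowed_features model sketched earlier:

# Re-checking the deleted test with the allowed_features sketch from above
# (illustrative helper, not XGBoost API).  Constraints: [[0, 1], [1, 2, 3]].
sets = [[0, 1], [1, 2, 3]]
# Node 2 sits below a root split on feature 4: only feature 4 survives.
assert allowed_features(sets, [4], range(8)) == [4]
# Node 3 sits below splits on features 4 and 5: exactly {4, 5} survive.
assert allowed_features(sets, [4, 5], range(8)) == [4, 5]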

@ -11,7 +11,7 @@ class TestGPUInteractionConstraints(unittest.TestCase):
cputest = test_ic.TestInteractionConstraints() cputest = test_ic.TestInteractionConstraints()
def test_interaction_constraints(self): def test_interaction_constraints(self):
self.cputest.test_interaction_constraints(tree_method='gpu_hist') self.cputest.run_interaction_constraints(tree_method='gpu_hist')
def test_training_accuracy(self): def test_training_accuracy(self):
self.cputest.test_training_accuracy(tree_method='gpu_hist') self.cputest.training_accuracy(tree_method='gpu_hist')

@ -10,7 +10,7 @@ rng = np.random.RandomState(1994)
class TestInteractionConstraints(unittest.TestCase): class TestInteractionConstraints(unittest.TestCase):
def test_interaction_constraints(self, tree_method='hist'): def run_interaction_constraints(self, tree_method):
x1 = np.random.normal(loc=1.0, scale=1.0, size=1000) x1 = np.random.normal(loc=1.0, scale=1.0, size=1000)
x2 = np.random.normal(loc=1.0, scale=1.0, size=1000) x2 = np.random.normal(loc=1.0, scale=1.0, size=1000)
x3 = np.random.choice([1, 2, 3], size=1000, replace=True) x3 = np.random.choice([1, 2, 3], size=1000, replace=True)
@ -25,8 +25,7 @@ class TestInteractionConstraints(unittest.TestCase):
'eta': 0.1, 'eta': 0.1,
'nthread': 2, 'nthread': 2,
'interaction_constraints': '[[0, 1]]', 'interaction_constraints': '[[0, 1]]',
'tree_method': tree_method, 'tree_method': tree_method
'verbosity': 2
} }
num_boost_round = 12 num_boost_round = 12
# Fit a model that only allows interaction between x1 and x2 # Fit a model that only allows interaction between x1 and x2
@ -50,8 +49,17 @@ class TestInteractionConstraints(unittest.TestCase):
diff2 = preds[2] - preds[1] diff2 = preds[2] - preds[1]
assert np.all(np.abs(diff2 - diff2[0]) < 1e-4) assert np.all(np.abs(diff2 - diff2[0]) < 1e-4)
def test_exact_interaction_constraints(self):
self.run_interaction_constraints(tree_method='exact')
def test_hist_interaction_constraints(self):
self.run_interaction_constraints(tree_method='hist')
def test_approx_interaction_constraints(self):
self.run_interaction_constraints(tree_method='approx')
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def test_training_accuracy(self, tree_method='hist'): def training_accuracy(self, tree_method):
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
dtrain = xgboost.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1') dtrain = xgboost.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1')
dtest = xgboost.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1') dtest = xgboost.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1')
@ -73,3 +81,12 @@ class TestInteractionConstraints(unittest.TestCase):
bst = xgboost.train(params, dtrain, num_boost_round) bst = xgboost.train(params, dtrain, num_boost_round)
pred_dtest = (bst.predict(dtest) < 0.5) pred_dtest = (bst.predict(dtest) < 0.5)
assert accuracy_score(dtest.get_label(), pred_dtest) < 0.1 assert accuracy_score(dtest.get_label(), pred_dtest) < 0.1
def test_hist_training_accuracy(self):
self.training_accuracy(tree_method='hist')
def test_exact_training_accuracy(self):
self.training_accuracy(tree_method='exact')
def test_approx_training_accuracy(self):
self.training_accuracy(tree_method='approx')
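The rename from test_* to run_*/training_* is what lets the GPU suite above reuse these bodies without pytest collecting the shared helper a second time; thin per-method wrappers then pin each tree_method explicitly. A generic sketch of the pattern (names here are illustrative):

# Shared body loses its `test_` prefix so pytest does not collect it
# directly; wrappers select the tree_method.
class TestPattern:
    def run_case(self, tree_method):
        assert tree_method in ('exact', 'approx', 'hist', 'gpu_hist')

    def test_exact(self):
        self.run_case('exact')

    def test_hist(self):
        self.run_case('hist')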