From ae05948e32214740c48ec5bb4ef3fa0785842226 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 19 Jun 2019 18:11:02 +0800 Subject: [PATCH] Feature interaction for GPU Hist. (#4534) * GPU hist Interaction Constraints. * Duplicate related parameters. * Add tests for CPU interaction constraint. * Add better error reporting. * Thorough tests. --- include/xgboost/learner.h | 1 + src/common/span.h | 27 +- src/learner.cc | 1 + src/tree/constraints.cu | 347 ++++++++++++++++++ src/tree/constraints.cuh | 248 +++++++++++++ src/tree/param.h | 14 +- src/tree/split_evaluator.cc | 31 +- src/tree/updater_gpu_hist.cu | 50 ++- tests/cpp/tree/test_bitfield.cu | 60 +++ tests/cpp/tree/test_constraints.cu | 321 ++++++++++++++++ tests/cpp/tree/test_gpu_hist.cu | 67 ++-- tests/cpp/tree/test_split_evaluator.cc | 57 +++ .../test_gpu_interaction_constraints.py | 17 + tests/python/test_interaction_constraints.py | 36 +- 14 files changed, 1201 insertions(+), 76 deletions(-) create mode 100644 src/tree/constraints.cu create mode 100644 src/tree/constraints.cuh create mode 100644 tests/cpp/tree/test_bitfield.cu create mode 100644 tests/cpp/tree/test_constraints.cu create mode 100644 tests/cpp/tree/test_split_evaluator.cc create mode 100644 tests/python-gpu/test_gpu_interaction_constraints.py diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index bb4dbefc8..909dfd1c0 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include diff --git a/src/common/span.h b/src/common/span.h index cd28ca27e..f33c2eb89 100644 --- a/src/common/span.h +++ b/src/common/span.h @@ -70,16 +70,16 @@ namespace common { // Usual logging facility is not available inside device code. // TODO(trivialfis): Make dmlc check more generic. // assert is not supported in mac as of CUDA 10.0 -#define KERNEL_CHECK(cond) \ - do { \ - if (!(cond)) { \ - printf("\nKernel error:\n" \ - "In: %s, \tline: %d\n" \ - "\t%s\n\tExpecting: %s\n", \ - __FILE__, __LINE__, __PRETTY_FUNCTION__, # cond); \ - asm("trap;"); \ - } \ - } while (0); \ +#define KERNEL_CHECK(cond) \ + do { \ + if (!(cond)) { \ + printf("\nKernel error:\n" \ + "In: %s, \tline: %d\n" \ + "\t%s\n\tExpecting: %s\n", \ + __FILE__, __LINE__, __PRETTY_FUNCTION__, #cond); \ + asm("trap;"); \ + } \ + } while (0); #ifdef __CUDA_ARCH__ #define SPAN_CHECK KERNEL_CHECK @@ -140,10 +140,13 @@ class SpanIterator { SPAN_CHECK(index_ < span_->size()); return *(span_->data() + index_); } + XGBOOST_DEVICE reference operator[](difference_type n) const { + return *(*this + n); + } XGBOOST_DEVICE pointer operator->() const { SPAN_CHECK(index_ != span_->size()); - return span_->data() + index_; + return span_->data() + index_; } XGBOOST_DEVICE SpanIterator& operator++() { @@ -490,7 +493,7 @@ class Span { return data()[_idx]; } - XGBOOST_DEVICE constexpr reference operator()(index_type _idx) const { + XGBOOST_DEVICE reference operator()(index_type _idx) const { return this->operator[](_idx); } diff --git a/src/learner.cc b/src/learner.cc index 8c84b2434..cc5763526 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -6,6 +6,7 @@ */ #include #include +#include #include #include #include diff --git a/src/tree/constraints.cu b/src/tree/constraints.cu new file mode 100644 index 000000000..97b1449cb --- /dev/null +++ b/src/tree/constraints.cu @@ -0,0 +1,347 @@ +/*! + * Copyright 2019 XGBoost contributors + */ +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include "constraints.cuh" +#include "param.h" +#include "../common/span.h" +#include "../common/device_helpers.cuh" + + +namespace xgboost { + +BitField::value_type constexpr BitField::kValueSize; +BitField::value_type constexpr BitField::kOne; + +size_t FeatureInteractionConstraint::Features() const { + return d_sets_ptr_.size() - 1; +} + +void FeatureInteractionConstraint::Configure( + tree::TrainParam const& param, int32_t const n_features) { + has_constraint_ = true; + if (param.interaction_constraints.length() == 0) { + has_constraint_ = false; + return; + } + // --- Parse interaction constraints + std::istringstream iss(param.interaction_constraints); + dmlc::JSONReader reader(&iss); + // Interaction constraints parsed from string parameter. After + // parsing, this looks like {{0, 1, 2}, {2, 3 ,4}}. + std::vector> h_feature_constraints; + try { + reader.Read(&h_feature_constraints); + } catch (dmlc::Error const& e) { + LOG(FATAL) << "Failed to parse feature interaction constraint:\n" + << param.interaction_constraints << "\n" + << "With error:\n" << e.what(); + } + n_sets_ = h_feature_constraints.size(); + + size_t const n_feat_storage = BitField::ComputeStorageSize(n_features); + if (n_feat_storage == 0 && n_features != 0) { + LOG(FATAL) << "Wrong storage size, n_features: " << n_features; + } + + // --- Initialize allowed features attached to nodes. + if (param.max_depth == 0 && param.max_leaves == 0) { + LOG(FATAL) << "Max leaves and max depth cannot both be unconstrained for gpu_hist."; + } + int32_t n_nodes {0}; + if (param.max_depth != 0) { + n_nodes = std::pow(2, param.max_depth + 1); + } else { + n_nodes = param.max_leaves * 2 - 1; + } + CHECK_NE(n_nodes, 0); + node_constraints_.resize(n_nodes); + node_constraints_storage_.resize(n_nodes); + for (auto& n : node_constraints_storage_) { + n.resize(BitField::ComputeStorageSize(n_features)); + } + for (size_t i = 0; i < node_constraints_storage_.size(); ++i) { + auto span = dh::ToSpan(node_constraints_storage_[i]); + node_constraints_[i] = BitField(span); + } + s_node_constraints_ = common::Span(node_constraints_.data(), + node_constraints_.size()); + + // Represent constraints as CSR format, flatten is the value vector, + // ptr is row_ptr vector in CSR. + std::vector h_feature_constraints_flatten; + for (auto const& constraints : h_feature_constraints) { + for (int32_t c : constraints) { + h_feature_constraints_flatten.emplace_back(c); + } + } + std::vector h_feature_constraints_ptr; + size_t n_features_in_constraints = 0; + h_feature_constraints_ptr.emplace_back(n_features_in_constraints); + for (auto const& v : h_feature_constraints) { + n_features_in_constraints += v.size(); + h_feature_constraints_ptr.emplace_back(n_features_in_constraints); + } + // Copy the CSR to device. + d_fconstraints_.resize(h_feature_constraints_flatten.size()); + thrust::copy(h_feature_constraints_flatten.cbegin(), h_feature_constraints_flatten.cend(), + d_fconstraints_.begin()); + s_fconstraints_ = dh::ToSpan(d_fconstraints_); + d_fconstraints_ptr_.resize(h_feature_constraints_ptr.size()); + thrust::copy(h_feature_constraints_ptr.cbegin(), h_feature_constraints_ptr.cend(), + d_fconstraints_ptr_.begin()); + s_fconstraints_ptr_ = dh::ToSpan(d_fconstraints_ptr_); + + // --- Compute interaction sets attached to each feature. + // Use a set to eliminate duplicated entries. + std::vector > h_features_set(n_features); + int32_t cid = 0; + for (auto const& constraints : h_feature_constraints) { + for (auto const& feat : constraints) { + h_features_set.at(feat).insert(cid); + } + cid++; + } + // Compute device sets. + std::vector h_sets; + int32_t ptr = 0; + std::vector h_sets_ptr {ptr}; + for (auto const& feature : h_features_set) { + for (auto constraint_id : feature) { + h_sets.emplace_back(constraint_id); + } + // empty set is well defined here. + ptr += feature.size(); + h_sets_ptr.emplace_back(ptr); + } + d_sets_ = h_sets; + d_sets_ptr_ = h_sets_ptr; + s_sets_ = dh::ToSpan(d_sets_); + s_sets_ptr_ = dh::ToSpan(d_sets_ptr_); + + d_feature_buffer_storage_.resize(BitField::ComputeStorageSize(n_features)); + feature_buffer_ = dh::ToSpan(d_feature_buffer_storage_); + + // --- Initialize result buffers. + output_buffer_bits_storage_.resize(n_features); + output_buffer_bits_ = BitField(dh::ToSpan(output_buffer_bits_storage_)); + input_buffer_bits_storage_.resize(n_features); + input_buffer_bits_ = BitField(dh::ToSpan(input_buffer_bits_storage_)); + result_buffer_.resize(n_features); + s_result_buffer_ = dh::ToSpan(result_buffer_); +} + +FeatureInteractionConstraint::FeatureInteractionConstraint( + tree::TrainParam const& param, int32_t const n_features) : + has_constraint_{true}, n_sets_{0} { + this->Configure(param, n_features); +} + +void FeatureInteractionConstraint::Reset() { + for (auto& node : node_constraints_storage_) { + thrust::fill(node.begin(), node.end(), 0); + } +} + +__global__ void ClearBuffersKernel( + BitField result_buffer_self, BitField result_buffer_input, BitField feature_buffer) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < result_buffer_self.Size()) { + result_buffer_self.Clear(tid); + } + if (tid < result_buffer_input.Size()) { + result_buffer_input.Clear(tid); + } +} + +void FeatureInteractionConstraint::ClearBuffers() { + CHECK_EQ(output_buffer_bits_.Size(), input_buffer_bits_.Size()); + CHECK_LE(feature_buffer_.Size(), output_buffer_bits_.Size()); + int constexpr kBlockThreads = 256; + const int n_grids = static_cast( + dh::DivRoundUp(input_buffer_bits_.Size(), kBlockThreads)); + ClearBuffersKernel<<>>( + output_buffer_bits_, input_buffer_bits_, feature_buffer_); +} + +common::Span FeatureInteractionConstraint::QueryNode(int32_t node_id) { + if (!has_constraint_) { return {}; } + CHECK_LT(node_id, s_node_constraints_.size()); + + ClearBuffers(); + + thrust::counting_iterator begin(0); + thrust::counting_iterator end(result_buffer_.size()); + auto p_result_buffer = result_buffer_.data(); + BitField node_constraints = s_node_constraints_[node_id]; + + thrust::device_ptr const out_end = thrust::copy_if( + thrust::device, + begin, end, + p_result_buffer, + [=]__device__(int32_t pos) { + bool res = node_constraints.Check(pos); + return res; + }); + size_t const n_available = std::distance(result_buffer_.data(), out_end); + + return {s_result_buffer_.data(), s_result_buffer_.data() + n_available}; +} + +__global__ void QueryFeatureListKernel(common::Span feature_list_input, + common::Span node_feature_list, + BitField result_buffer_input, + BitField result_buffer_output) { + uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < feature_list_input.size()) { + result_buffer_input.Set(feature_list_input[tid]); + } + + if (tid < node_feature_list.size()) { + result_buffer_output.Set(node_feature_list[tid]); + } + result_buffer_output &= result_buffer_input; +} + +common::Span FeatureInteractionConstraint::Query( + common::Span feature_list, int32_t nid) { + if (!has_constraint_ || nid == 0) { + return feature_list; + } + auto selected = this->QueryNode(nid); + CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size()); + int constexpr kBlockThreads = 256; + const int n_grids = static_cast( + dh::DivRoundUp(output_buffer_bits_.Size(), kBlockThreads)); + + QueryFeatureListKernel<<>> + (feature_list, + selected, + input_buffer_bits_, + output_buffer_bits_); + + thrust::counting_iterator begin(0); + thrust::counting_iterator end(result_buffer_.size()); + + BitField local_result_buffer = output_buffer_bits_; + + thrust::device_ptr const out_end = thrust::copy_if( + thrust::device, + begin, end, + result_buffer_.data(), + [=]__device__(int32_t pos) { + bool res = local_result_buffer.Check(pos); + return res; + }); + size_t const n_available = std::distance(result_buffer_.data(), out_end); + + common::Span result = + {s_result_buffer_.data(), s_result_buffer_.data() + n_available}; + return result; +} + +// Find interaction sets for each feature, then store all features in +// those sets in a buffer. +__global__ void RestoreFeatureListFromSetsKernel( + BitField feature_buffer, + + int32_t fid, + common::Span feature_interactions, + common::Span feature_interactions_ptr, // of size n interaction set + 1 + + common::Span interactions_list, + common::Span interactions_list_ptr) { + auto const tid_x = threadIdx.x + blockIdx.x * blockDim.x; + auto const tid_y = threadIdx.y + blockIdx.y * blockDim.y; + // painful mapping: fid -> sets related to it -> features related to sets. + auto const beg = interactions_list_ptr[fid]; + auto const end = interactions_list_ptr[fid+1]; + auto const n_sets = end - beg; + if (tid_x < n_sets) { + auto const set_id_pos = beg + tid_x; + auto const set_id = interactions_list[set_id_pos]; + auto const set_beg = feature_interactions_ptr[set_id]; + auto const set_end = feature_interactions_ptr[set_id + 1]; + auto const feature_pos = set_beg + tid_y; + if (feature_pos < set_end) { + feature_buffer.Set(feature_interactions[feature_pos]); + } + } +} + +__global__ void InteractionConstraintSplitKernel(BitField feature, + int32_t feature_id, + BitField node, + BitField left, + BitField right) { + auto tid = threadIdx.x + blockDim.x * blockIdx.x; + if (tid > node.Size()) { + return; + } + // enable constraints from feature + node |= feature; + // clear the buffer after use + if (tid < feature.Size()) { + feature.Clear(tid); + } + + // enable constraints from parent + left |= node; + right |= node; + + if (tid == feature_id) { + // enable the split feature, set all of them at last instead of + // setting it for parent to avoid race. + node.Set(feature_id); + left.Set(feature_id); + right.Set(feature_id); + } +} + +void FeatureInteractionConstraint::Split( + int32_t node_id, int32_t feature_id, int32_t left_id, int32_t right_id) { + if (!has_constraint_) { return; } + CHECK_NE(node_id, left_id) + << " Split node: " << node_id << " and its left child: " + << left_id << " cannot be the same."; + CHECK_NE(node_id, right_id) + << " Split node: " << node_id << " and its left child: " + << right_id << " cannot be the same."; + CHECK_LT(right_id, s_node_constraints_.size()); + CHECK_NE(s_node_constraints_.size(), 0); + + BitField node = s_node_constraints_[node_id]; + BitField left = s_node_constraints_[left_id]; + BitField right = s_node_constraints_[right_id]; + + dim3 const block3(16, 64, 1); + dim3 const grid3(dh::DivRoundUp(n_sets_, 16), + dh::DivRoundUp(s_fconstraints_.size(), 64)); + RestoreFeatureListFromSetsKernel<<>> + (feature_buffer_, + feature_id, + s_fconstraints_, + s_fconstraints_ptr_, + s_sets_, + s_sets_ptr_); + + int constexpr kBlockThreads = 256; + const int n_grids = static_cast(dh::DivRoundUp(node.Size(), kBlockThreads)); + InteractionConstraintSplitKernel<<>> + (feature_buffer_, + feature_id, + node, left, right); +} + +} // namespace xgboost diff --git a/src/tree/constraints.cuh b/src/tree/constraints.cuh new file mode 100644 index 000000000..eebbea5da --- /dev/null +++ b/src/tree/constraints.cuh @@ -0,0 +1,248 @@ +/*! + * Copyright 2019 XGBoost contributors + */ +#ifndef XGBOOST_TREE_CONSTRAINTS_H_ +#define XGBOOST_TREE_CONSTRAINTS_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "param.h" +#include "../common/span.h" +#include "../common/device_helpers.cuh" + +#include + +namespace xgboost { + +__forceinline__ __device__ unsigned long long AtomicOr(unsigned long long* address, + unsigned long long val) { + unsigned long long int old = *address, assumed; // NOLINT + do { + assumed = old; + old = atomicCAS(address, assumed, val | assumed); + } while (assumed != old); + + return old; +} + +__forceinline__ __device__ unsigned long long AtomicAnd(unsigned long long* address, + unsigned long long val) { + unsigned long long int old = *address, assumed; // NOLINT + do { + assumed = old; + old = atomicCAS(address, assumed, val & assumed); + } while (assumed != old); + + return old; +} + +/*! + * \brief A non-owning type with auxiliary methods defined for manipulating bits. + */ +struct BitField { + using value_type = uint64_t; + + static value_type constexpr kValueSize = sizeof(value_type) * 8; + static value_type constexpr kOne = 1UL; // force uint64_t + static_assert(kValueSize == 64, "uint64_t should be of 64 bits."); + + struct Pos { + value_type int_pos {0}; + value_type bit_pos {0}; + }; + + common::Span bits_; + + public: + BitField() = default; + XGBOOST_DEVICE BitField(common::Span bits) : bits_{bits} {} + XGBOOST_DEVICE BitField(BitField const& other) : bits_{other.bits_} {} + + static size_t ComputeStorageSize(size_t size) { + auto pos = ToBitPos(size); + if (size < kValueSize) { + return 1; + } + + if (pos.bit_pos != 0) { + return pos.int_pos + 2; + } else { + return pos.int_pos + 1; + } + } + XGBOOST_DEVICE static Pos ToBitPos(value_type pos) { + Pos pos_v; + if (pos == 0) { + return pos_v; + } + pos_v.int_pos = pos / kValueSize; + pos_v.bit_pos = pos % kValueSize; + return pos_v; + } + + __device__ BitField& operator|=(BitField const& rhs) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + size_t min_size = min(bits_.size(), rhs.bits_.size()); + if (tid < min_size) { + bits_[tid] |= rhs.bits_[tid]; + } + return *this; + } + __device__ BitField& operator&=(BitField const& rhs) { + size_t min_size = min(bits_.size(), rhs.bits_.size()); + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < min_size) { + bits_[tid] &= rhs.bits_[tid]; + } + return *this; + } + + XGBOOST_DEVICE size_t Size() const { return kValueSize * bits_.size(); } + + __device__ void Set(value_type pos) { + Pos pos_v = ToBitPos(pos); + value_type& value = bits_[pos_v.int_pos]; + value_type set_bit = kOne << (kValueSize - pos_v.bit_pos - kOne); + static_assert(sizeof(unsigned long long int) == sizeof(value_type), ""); + AtomicOr(reinterpret_cast(&value), set_bit); + } + __device__ void Clear(value_type pos) { + Pos pos_v = ToBitPos(pos); + value_type& value = bits_[pos_v.int_pos]; + value_type clear_bit = ~(kOne << (kValueSize - pos_v.bit_pos - kOne)); + static_assert(sizeof(unsigned long long int) == sizeof(value_type), ""); + AtomicAnd(reinterpret_cast(&value), clear_bit); + } + + XGBOOST_DEVICE bool Check(Pos pos_v) const { + value_type value = bits_[pos_v.int_pos]; + value_type const test_bit = kOne << (kValueSize - pos_v.bit_pos - kOne); + value_type result = test_bit & value; + return static_cast(result); + } + XGBOOST_DEVICE bool Check(value_type pos) const { + Pos pos_v = ToBitPos(pos); + return Check(pos_v); + } + + friend std::ostream& operator<<(std::ostream& os, BitField field) { + os << "Bits " << "storage size: " << field.bits_.size() << "\n"; + for (size_t i = 0; i < field.bits_.size(); ++i) { + std::bitset set(field.bits_[i]); + os << set << "\n"; + } + return os; + } +}; + +inline void PrintDeviceBits(std::string name, BitField field) { + std::cout << "Bits: " << name << std::endl; + std::vector h_field_bits(field.bits_.size()); + thrust::copy(thrust::device_ptr(field.bits_.data()), + thrust::device_ptr(field.bits_.data() + field.bits_.size()), + h_field_bits.data()); + BitField h_field; + h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()}; + std::cout << h_field; +} + +inline void PrintDeviceStorage(std::string name, common::Span list) { + std::cout << name << std::endl; + std::vector h_list(list.size()); + thrust::copy(thrust::device_ptr(list.data()), + thrust::device_ptr(list.data() + list.size()), + h_list.data()); + for (auto v : h_list) { + std::cout << v << ", "; + } + std::cout << std::endl; +} + +// Feature interaction constraints built for GPU Hist updater. +struct FeatureInteractionConstraint { + protected: + // Whether interaction constraint is used. + bool has_constraint_; + // n interaction sets. + int32_t n_sets_; + + // The parsed feature interaction constraints as CSR. + dh::device_vector d_fconstraints_; + common::Span s_fconstraints_; + dh::device_vector d_fconstraints_ptr_; + common::Span s_fconstraints_ptr_; + /* Interaction sets for each feature as CSR. For an input like: + * [[0, 1], [1, 2]], this will have values: + * + * fid: |0 | 1 | 2| + * sets a feature belongs to(d_sets_): |0 |0, 1| 1| + * + * d_sets_ptr_: |0, 1, 3, 4| + */ + dh::device_vector d_sets_; + common::Span s_sets_; + dh::device_vector d_sets_ptr_; + common::Span s_sets_ptr_; + + // Allowed features attached to each node, have n_nodes bitfields, + // each of size n_features. + std::vector> node_constraints_storage_; + std::vector node_constraints_; + common::Span s_node_constraints_; + + // buffer storing return feature list from Query, of size n_features. + dh::device_vector result_buffer_; + common::Span s_result_buffer_; + + // Temp buffers, one bit for each possible feature. + dh::device_vector output_buffer_bits_storage_; + BitField output_buffer_bits_; + dh::device_vector input_buffer_bits_storage_; + BitField input_buffer_bits_; + /* + * Combined features from all interaction sets that one feature belongs to. + * For an input with [[0, 1], [1, 2]], the feature 1 belongs to sets {0, 1} + */ + dh::device_vector d_feature_buffer_storage_; + BitField feature_buffer_; // of Size n features. + + // Clear out all temp buffers except for `feature_buffer_', which is + // handled in `Split'. + void ClearBuffers(); + + public: + size_t Features() const; + FeatureInteractionConstraint() = default; + void Configure(tree::TrainParam const& param, int32_t const n_features); + FeatureInteractionConstraint(tree::TrainParam const& param, int32_t const n_features); + FeatureInteractionConstraint(FeatureInteractionConstraint const& that) = default; + FeatureInteractionConstraint(FeatureInteractionConstraint&& that) = default; + /*! \brief Reset before constructing a new tree. */ + void Reset(); + /*! \brief Return a list of features given node id */ + common::Span QueryNode(int32_t nid); + /*! + * \brief Return a list of selected features from given feature_list and node id. + * + * \param feature_list A list of features + * \param nid node id + * + * \return A list of features picked from `feature_list' that conform to constraints in + * node. + */ + common::Span Query(common::Span feature_list, int32_t nid); + /*! \brief Apply split for node_id. */ + void Split(int32_t node_id, int32_t feature_id, int32_t left_id, int32_t right_id); +}; + +} // namespace xgboost +#endif // XGBOOST_TREE_CONSTRAINTS_H_ diff --git a/src/tree/param.h b/src/tree/param.h index d0d49a403..b7594cbec 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -70,8 +70,13 @@ struct TrainParam : public dmlc::Parameter { bool cache_opt; // whether refresh updater needs to update the leaf values bool refresh_leaf; - // auxiliary data structure + + // FIXME(trivialfis): Following constraints are used by gpu + // algorithm, duplicated with those defined split evaluator due to + // their different code paths. std::vector monotone_constraints; + std::string interaction_constraints; + // the criteria to use for ranking splits std::string split_evaluator; @@ -187,6 +192,13 @@ struct TrainParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(monotone_constraints) .set_default(std::vector()) .describe("Constraint of variable monotonicity"); + DMLC_DECLARE_FIELD(interaction_constraints) + .set_default("") + .describe("Constraints for interaction representing permitted interactions." + "The constraints must be specified in the form of a nest list," + "e.g. [[0, 1], [2, 3, 4]], where each inner list is a group of" + "indices of features that are allowed to interact with each other." + "See tutorial for more information"); DMLC_DECLARE_FIELD(split_evaluator) .set_default("elastic_net,monotonic,interaction") .describe("The criteria to use for ranking splits"); diff --git a/src/tree/split_evaluator.cc b/src/tree/split_evaluator.cc index 55d2b99ff..5c43567de 100644 --- a/src/tree/split_evaluator.cc +++ b/src/tree/split_evaluator.cc @@ -6,6 +6,7 @@ #include "split_evaluator.h" #include #include +#include #include #include #include @@ -384,17 +385,23 @@ class InteractionConstraint final : public SplitEvaluator { // Read std::vector> first and then // convert to std::vector> std::vector> tmp; - reader.Read(&tmp); + try { + reader.Read(&tmp); + } catch (dmlc::Error const& e) { + LOG(FATAL) << "Failed to parse feature interaction constraint:\n" + << params_.interaction_constraints << "\n" + << "With error:\n" << e.what(); + } for (const auto& e : tmp) { interaction_constraints_.emplace_back(e.begin(), e.end()); } // Initialise interaction constraints record with all variables permitted for the first node - int_cont_.clear(); - int_cont_.resize(1, std::unordered_set()); - int_cont_[0].reserve(params_.num_feature); + node_constraints_.clear(); + node_constraints_.resize(1, std::unordered_set()); + node_constraints_[0].reserve(params_.num_feature); for (bst_uint i = 0; i < params_.num_feature; ++i) { - int_cont_[0].insert(i); + node_constraints_[0].insert(i); } // Initialise splits record @@ -463,12 +470,12 @@ class InteractionConstraint final : public SplitEvaluator { splits_[rightid] = feature_splits; // Resize constraints record, initialise all features to be not permitted for new nodes - int_cont_.resize(newsize, std::unordered_set()); + node_constraints_.resize(newsize, std::unordered_set()); // Permit features used in previous splits for (bst_uint fid : feature_splits) { - int_cont_[leftid].insert(fid); - int_cont_[rightid].insert(fid); + node_constraints_[leftid].insert(fid); + node_constraints_[rightid].insert(fid); } // Loop across specified interactions in constraints @@ -486,8 +493,8 @@ class InteractionConstraint final : public SplitEvaluator { // If interaction is still relevant, permit all other features in the interaction if (flag == 1) { for (bst_uint k : constraint) { - int_cont_[leftid].insert(k); - int_cont_[rightid].insert(k); + node_constraints_[leftid].insert(k); + node_constraints_[rightid].insert(k); } } } @@ -506,7 +513,7 @@ class InteractionConstraint final : public SplitEvaluator { std::vector< std::unordered_set > interaction_constraints_; // int_cont_[nid] contains the set of all feature IDs that are allowed to // be used for a split at node nid - std::vector< std::unordered_set > int_cont_; + std::vector< std::unordered_set > node_constraints_; // splits_[nid] contains the set of all feature IDs that have been used for // splits in node nid and its parents std::vector< std::unordered_set > splits_; @@ -516,7 +523,7 @@ class InteractionConstraint final : public SplitEvaluator { inline bool CheckInteractionConstraint(bst_uint featureid, bst_uint nodeid) const { // short-circuit if no constraint is specified return (params_.interaction_constraints.empty() - || int_cont_.at(nodeid).count(featureid) > 0); + || node_constraints_.at(nodeid).count(featureid) > 0); } }; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index f34463d43..9750c6f8c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -24,6 +24,7 @@ #include "../common/span.h" #include "param.h" #include "updater_gpu_common.cuh" +#include "constraints.cuh" namespace xgboost { namespace tree { @@ -318,9 +319,8 @@ __device__ void EvaluateFeature( template __global__ void EvaluateSplitKernel( - common::Span - node_histogram, // histogram for gradients - common::Span feature_set, // Selected features + common::Span node_histogram, // histogram for gradients + common::Span feature_set, // Selected features DeviceNodeStats node, ELLPackMatrix matrix, GPUTrainingParam gpu_param, @@ -354,6 +354,7 @@ __global__ void EvaluateSplitKernel( // One block for each feature. Features are sampled, so fidx != blockIdx.x int fidx = feature_set[blockIdx.x]; + int constraint = d_monotonic_constraints[fidx]; EvaluateFeature( fidx, node_histogram, matrix, &best_split, node, gpu_param, &temp_storage, @@ -714,6 +715,7 @@ struct DeviceShard { common::Monitor monitor; std::vector node_value_constraints; common::ColumnSampler column_sampler; + FeatureInteractionConstraint interaction_constraints; using ExpandQueue = std::priority_queue, @@ -721,7 +723,8 @@ struct DeviceShard { std::unique_ptr qexpand; DeviceShard(int _device_id, int shard_idx, bst_uint row_begin, - bst_uint row_end, TrainParam _param, uint32_t column_sampler_seed) + bst_uint row_end, TrainParam _param, uint32_t column_sampler_seed, + uint32_t n_features) : device_id(_device_id), shard_idx(shard_idx), row_begin_idx(row_begin), @@ -730,7 +733,8 @@ struct DeviceShard { n_bins(0), param(std::move(_param)), prediction_cache_initialised(false), - column_sampler(column_sampler_seed) { + column_sampler(column_sampler_seed), + interaction_constraints(param, n_features) { monitor.Init(std::string("DeviceShard") + std::to_string(device_id)); } @@ -778,6 +782,8 @@ struct DeviceShard { this->column_sampler.Init(num_columns, param.colsample_bynode, param.colsample_bylevel, param.colsample_bytree); dh::safe_cuda(cudaSetDevice(device_id)); + this->interaction_constraints.Reset(); + thrust::fill( thrust::device_pointer_cast(position.Current()), thrust::device_pointer_cast(position.Current() + position.Size()), 0); @@ -806,7 +812,7 @@ struct DeviceShard { std::vector nidxs, const RegTree& tree, size_t num_columns) { dh::safe_cuda(cudaSetDevice(device_id)); - auto result = pinned_memory.GetSpan(nidxs.size()); + auto result_all = pinned_memory.GetSpan(nidxs.size()); // Work out cub temporary memory requirement GPUTrainingParam gpu_param(param); @@ -840,11 +846,26 @@ struct DeviceShard { auto nidx = nidxs[i]; auto p_feature_set = column_sampler.GetFeatureSet(tree.GetDepth(nidx)); p_feature_set->Shard(GPUSet(device_id, 1)); - auto d_feature_set = p_feature_set->DeviceSpan(device_id); + auto d_sampled_features = p_feature_set->DeviceSpan(device_id); + common::Span d_feature_set = + interaction_constraints.Query(d_sampled_features, nidx); auto d_split_candidates = d_split_candidates_all.subspan(i * num_columns, d_feature_set.size()); + DeviceNodeStats node(node_sum_gradients[nidx], nidx, param); + auto d_result = d_result_all.subspan(i, 1); + if (d_feature_set.size() == 0) { + // Acting as a device side constructor for DeviceSplitCandidate. + // DeviceSplitCandidate::IsValid is false so that ApplySplit can reject this + // candidate. + auto worst_candidate = DeviceSplitCandidate(); + dh::safe_cuda(cudaMemcpyAsync(d_result.data(), &worst_candidate, + sizeof(DeviceSplitCandidate), + cudaMemcpyHostToDevice)); + continue; + } + // One block for each feature int constexpr kBlockThreads = 256; EvaluateSplitKernel @@ -854,7 +875,6 @@ struct DeviceShard { monotone_constraints); // Reduce over features to find best feature - auto d_result = d_result_all.subspan(i, 1); auto d_cub_memory = d_cub_memory_all.subspan(i * cub_memory_size, cub_memory_size); size_t cub_bytes = d_cub_memory.size() * sizeof(DeviceSplitCandidate); @@ -864,11 +884,10 @@ struct DeviceShard { DeviceSplitCandidate(), streams[i]); } - dh::safe_cuda(cudaMemcpy(result.data(), d_result_all.data(), + dh::safe_cuda(cudaMemcpy(result_all.data(), d_result_all.data(), sizeof(DeviceSplitCandidate) * d_result_all.size(), cudaMemcpyDeviceToHost)); - - return std::vector(result.begin(), result.end()); + return std::vector(result_all.begin(), result_all.end()); } void BuildHist(int nidx) { @@ -1137,6 +1156,10 @@ struct DeviceShard { candidate.split.left_sum; node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum; + + interaction_constraints.Split(candidate.nid, tree[candidate.nid].SplitIndex(), + tree[candidate.nid].LeftChild(), + tree[candidate.nid].RightChild()); } void InitRoot(RegTree* p_tree, HostDeviceVector* gpair_all, @@ -1202,7 +1225,7 @@ struct DeviceShard { int right_child_nidx = tree[candidate.nid].RightChild(); // Only create child entries if needed if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), - num_leaves)) { + num_leaves)) { monitor.StartCuda("UpdatePosition"); this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]); monitor.StopCuda("UpdatePosition"); @@ -1487,7 +1510,8 @@ class GPUHistMakerSpecialised { shard = std::unique_ptr>( new DeviceShard(dist_.Devices().DeviceId(idx), idx, start, start + size, param_, - column_sampling_seed)); + column_sampling_seed, + info_->num_col_)); }); monitor_.StartCuda("Quantiles"); diff --git a/tests/cpp/tree/test_bitfield.cu b/tests/cpp/tree/test_bitfield.cu new file mode 100644 index 000000000..aa5d36d49 --- /dev/null +++ b/tests/cpp/tree/test_bitfield.cu @@ -0,0 +1,60 @@ +/*! + * Copyright 2019 XGBoost contributors + */ +#include +#include +#include +#include +#include "../../../src/tree/constraints.cuh" +#include "../../../src/common/device_helpers.cuh" + +namespace xgboost { + +__global__ void TestSetKernel(BitField bits) { + auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < bits.Size()) { + bits.Set(tid); + } +} + +TEST(BitField, Set) { + dh::device_vector storage; + uint32_t constexpr kBits = 128; + storage.resize(128); + auto bits = BitField(dh::ToSpan(storage)); + TestSetKernel<<<1, kBits>>>(bits); + + std::vector h_storage(storage.size()); + thrust::copy(storage.begin(), storage.end(), h_storage.begin()); + + BitField outputs { + common::Span{h_storage.data(), + h_storage.data() + h_storage.size()}}; + for (size_t i = 0; i < kBits; ++i) { + ASSERT_TRUE(outputs.Check(i)); + } +} + +__global__ void TestOrKernel(BitField lhs, BitField rhs) { + lhs |= rhs; +} + +TEST(BitField, And) { + uint32_t constexpr kBits = 128; + dh::device_vector lhs_storage(kBits); + dh::device_vector rhs_storage(kBits); + auto lhs = BitField(dh::ToSpan(lhs_storage)); + auto rhs = BitField(dh::ToSpan(rhs_storage)); + thrust::fill(lhs_storage.begin(), lhs_storage.end(), 0UL); + thrust::fill(rhs_storage.begin(), rhs_storage.end(), ~static_cast(0UL)); + TestOrKernel<<<1, kBits>>>(lhs, rhs); + + std::vector h_storage(lhs_storage.size()); + thrust::copy(lhs_storage.begin(), lhs_storage.end(), h_storage.begin()); + BitField outputs {{h_storage.data(), h_storage.data() + h_storage.size()}}; + for (size_t i = 0; i < kBits; ++i) { + ASSERT_TRUE(outputs.Check(i)); + } +} + +} // namespace xgboost \ No newline at end of file diff --git a/tests/cpp/tree/test_constraints.cu b/tests/cpp/tree/test_constraints.cu new file mode 100644 index 000000000..b2051f0b3 --- /dev/null +++ b/tests/cpp/tree/test_constraints.cu @@ -0,0 +1,321 @@ +/*! + * Copyright 2019 XGBoost contributors + */ +#include +#include +#include +#include +#include +#include +#include +#include "../../../src/tree/constraints.cuh" +#include "../../../src/tree/param.h" +#include "../../../src/common/device_helpers.cuh" + +namespace xgboost { +namespace { + +struct FConstraintWrapper : public FeatureInteractionConstraint { + common::Span GetNodeConstraints() { + return FeatureInteractionConstraint::s_node_constraints_; + } + FConstraintWrapper(tree::TrainParam param, int32_t n_features) : + FeatureInteractionConstraint(param, n_features) {} + + dh::device_vector const& GetDSets() const { + return d_sets_; + } + dh::device_vector const& GetDSetsPtr() const { + return d_sets_ptr_; + } +}; + +std::string GetConstraintsStr() { + std::string const constraints_str = R"constraint([[1, 2], [3, 4, 5]])constraint"; + return constraints_str; +} + +tree::TrainParam GetParameter() { + std::vector> args{ + {"interaction_constraints", GetConstraintsStr()} + }; + tree::TrainParam param; + param.Init(args); + return param; +} + +void CompareBitField(BitField d_field, std::set positions) { + std::vector h_field_storage(d_field.bits_.size()); + thrust::copy(thrust::device_ptr(d_field.bits_.data()), + thrust::device_ptr( + d_field.bits_.data() + d_field.bits_.size()), + h_field_storage.data()); + BitField h_field; + h_field.bits_ = {h_field_storage.data(), h_field_storage.data() + h_field_storage.size()}; + + for (size_t i = 0; i < h_field.Size(); ++i) { + if (positions.find(i) != positions.cend()) { + ASSERT_TRUE(h_field.Check(i)); + } else { + ASSERT_FALSE(h_field.Check(i)); + } + } +} + +} // anonymous namespace + + +TEST(FeatureInteractionConstraint, Init) { + { + int32_t constexpr kFeatures = 6; + tree::TrainParam param = GetParameter(); + FConstraintWrapper constraints(param, kFeatures); + ASSERT_EQ(constraints.Features(), kFeatures); + common::Span s_nodes_constraints = constraints.GetNodeConstraints(); + for (BitField const& d_node : s_nodes_constraints) { + std::vector h_node_storage(d_node.bits_.size()); + thrust::copy(thrust::device_ptr(d_node.bits_.data()), + thrust::device_ptr( + d_node.bits_.data() + d_node.bits_.size()), + h_node_storage.data()); + BitField h_node; + h_node.bits_ = {h_node_storage.data(), h_node_storage.data() + h_node_storage.size()}; + // no feature is attached to node. + for (size_t i = 0; i < h_node.Size(); ++i) { + ASSERT_FALSE(h_node.Check(i)); + } + } + } + + { + // Test one feature in multiple sets + int32_t constexpr kFeatures = 7; + tree::TrainParam param = GetParameter(); + param.interaction_constraints = R"([[0, 1, 3], [3, 5, 6]])"; + FConstraintWrapper constraints(param, kFeatures); + std::vector h_sets {0, 0, 0, 1, 1, 1}; + std::vector h_sets_ptr {0, 1, 2, 2, 4, 4, 5, 6}; + auto d_sets = constraints.GetDSets(); + ASSERT_EQ(h_sets.size(), d_sets.size()); + auto d_sets_ptr = constraints.GetDSetsPtr(); + ASSERT_EQ(h_sets_ptr, d_sets_ptr); + for (size_t i = 0; i < h_sets.size(); ++i) { + ASSERT_EQ(h_sets[i], d_sets[i]); + } + for (size_t i = 0; i < h_sets_ptr.size(); ++i) { + ASSERT_EQ(h_sets_ptr[i], d_sets_ptr[i]); + } + } + + { + // Test having more than 1 BitField::value_type + int32_t constexpr kFeatures = 129; + tree::TrainParam param = GetParameter(); + param.interaction_constraints = R"([[0, 1, 3], [3, 5, 128], [127, 128]])"; + FConstraintWrapper constraints(param, kFeatures); + auto d_sets = constraints.GetDSets(); + auto d_sets_ptr = constraints.GetDSetsPtr(); + auto _128_beg = d_sets_ptr[128]; + auto _128_end = d_sets_ptr[128 + 1]; + ASSERT_EQ(_128_end - _128_beg, 2); + ASSERT_EQ(d_sets[_128_beg], 1); + ASSERT_EQ(d_sets[_128_end-1], 2); + } +} + +TEST(FeatureInteractionConstraint, Split) { + tree::TrainParam param = GetParameter(); + int32_t constexpr kFeatures = 6; + FConstraintWrapper constraints(param, kFeatures); + + { + BitField d_node[3]; + constraints.Split(0, /*feature_id=*/1, 1, 2); + for (size_t nid = 0; nid < 3; ++nid) { + d_node[nid] = constraints.GetNodeConstraints()[nid]; + ASSERT_EQ(d_node[nid].bits_.size(), 1); + CompareBitField(d_node[nid], {1, 2}); + } + } + + { + BitField d_node[5]; + constraints.Split(1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4); + for (auto nid : {1, 3, 4}) { + d_node[nid] = constraints.GetNodeConstraints()[nid]; + CompareBitField(d_node[nid], {0, 1, 2}); + } + for (auto nid : {0, 2}) { + d_node[nid] = constraints.GetNodeConstraints()[nid]; + CompareBitField(d_node[nid], {1, 2}); + } + } +} + +TEST(FeatureInteractionConstraint, QueryNode) { + tree::TrainParam param = GetParameter(); + int32_t constexpr kFeatures = 6; + FConstraintWrapper constraints(param, kFeatures); + + { + auto span = constraints.QueryNode(0); + ASSERT_EQ(span.size(), 0); + } + + { + constraints.Split(/*node_id=*/ 0, /*feature_id=*/ 1, 1, 2); + auto span = constraints.QueryNode(0); + std::vector h_result (span.size()); + thrust::copy(thrust::device_ptr(span.data()), + thrust::device_ptr(span.data() + span.size()), + h_result.begin()); + ASSERT_EQ(h_result.size(), 2); + ASSERT_EQ(h_result[0], 1); + ASSERT_EQ(h_result[1], 2); + } + + { + constraints.Split(1, /*feature_id=*/0, 3, 4); + auto span = constraints.QueryNode(1); + std::vector h_result (span.size()); + thrust::copy(thrust::device_ptr(span.data()), + thrust::device_ptr(span.data() + span.size()), + h_result.begin()); + ASSERT_EQ(h_result.size(), 3); + ASSERT_EQ(h_result[0], 0); + ASSERT_EQ(h_result[1], 1); + ASSERT_EQ(h_result[2], 2); + + // same as parent + span = constraints.QueryNode(3); + h_result.resize(span.size()); + thrust::copy(thrust::device_ptr(span.data()), + thrust::device_ptr(span.data() + span.size()), + h_result.begin()); + ASSERT_EQ(h_result.size(), 3); + ASSERT_EQ(h_result[0], 0); + ASSERT_EQ(h_result[1], 1); + ASSERT_EQ(h_result[2], 2); + } + + { + tree::TrainParam large_param = GetParameter(); + large_param.interaction_constraints = R"([[1, 139], [244, 0], [139, 221]])"; + FConstraintWrapper large_features(large_param, 256); + large_features.Split(0, 139, 1, 2); + auto span = large_features.QueryNode(0); + std::vector h_result (span.size()); + thrust::copy(thrust::device_ptr(span.data()), + thrust::device_ptr(span.data() + span.size()), + h_result.begin()); + ASSERT_EQ(h_result.size(), 3); + ASSERT_EQ(h_result[0], 1); + ASSERT_EQ(h_result[1], 139); + ASSERT_EQ(h_result[2], 221); + } +} + +namespace { + +void CompareFeatureList(common::Span s_output, std::vector solution) { + std::vector h_output(s_output.size()); + thrust::copy(thrust::device_ptr(s_output.data()), + thrust::device_ptr(s_output.data() + s_output.size()), + h_output.begin()); + ASSERT_EQ(h_output.size(), solution.size()); + for (size_t i = 0; i < solution.size(); ++i) { + ASSERT_EQ(h_output[i], solution[i]); + } +} + +} // anonymous namespace + +TEST(FeatureInteractionConstraint, Query) { + { + tree::TrainParam param = GetParameter(); + int32_t constexpr kFeatures = 6; + FConstraintWrapper constraints(param, kFeatures); + std::vector h_input_feature_list {0, 1, 2, 3, 4, 5}; + dh::device_vector d_input_feature_list (h_input_feature_list); + common::Span s_input_feature_list = dh::ToSpan(d_input_feature_list); + + auto s_output = constraints.Query(s_input_feature_list, 0); + CompareFeatureList(s_output, h_input_feature_list); + } + { + tree::TrainParam param = GetParameter(); + int32_t constexpr kFeatures = 6; + FConstraintWrapper constraints(param, kFeatures); + constraints.Split(/*node_id=*/0, /*feature_id=*/1, /*left_id=*/1, /*right_id=*/2); + constraints.Split(/*node_id=*/1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4); + constraints.Split(/*node_id=*/4, /*feature_id=*/3, /*left_id=*/5, /*right_id=*/6); + /* + * (node id) [allowed features] + * + * (0) [1, 2] + * / \ + * {split at 0} \ + * / \ + * (1)[0, 1, 2] (2)[1, 2] + * / \ + * / {split at 3} + * / \ + * (3)[0, 1, 2] (4)[0, 1, 2, 3, 4, 5] + * + */ + + std::vector h_input_feature_list {0, 1, 2, 3, 4, 5}; + dh::device_vector d_input_feature_list (h_input_feature_list); + common::Span s_input_feature_list = dh::ToSpan(d_input_feature_list); + + auto s_output = constraints.Query(s_input_feature_list, 1); + CompareFeatureList(s_output, {0, 1, 2}); + s_output = constraints.Query(s_input_feature_list, 2); + CompareFeatureList(s_output, {1, 2}); + s_output = constraints.Query(s_input_feature_list, 3); + CompareFeatureList(s_output, {0, 1, 2}); + s_output = constraints.Query(s_input_feature_list, 4); + CompareFeatureList(s_output, {0, 1, 2, 3, 4, 5}); + s_output = constraints.Query(s_input_feature_list, 5); + CompareFeatureList(s_output, {0, 1, 2, 3, 4, 5}); + s_output = constraints.Query(s_input_feature_list, 6); + CompareFeatureList(s_output, {0, 1, 2, 3, 4, 5}); + } + + // Test shared feature + { + tree::TrainParam param = GetParameter(); + int32_t constexpr kFeatures = 6; + std::string const constraints_str = R"constraint([[1, 2], [2, 3, 4]])constraint"; + param.interaction_constraints = constraints_str; + + FConstraintWrapper constraints(param, kFeatures); + constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2); + + std::vector h_input_feature_list {0, 1, 2, 3, 4, 5}; + dh::device_vector d_input_feature_list (h_input_feature_list); + common::Span s_input_feature_list = dh::ToSpan(d_input_feature_list); + + auto s_output = constraints.Query(s_input_feature_list, 1); + CompareFeatureList(s_output, {1, 2, 3, 4}); + } + + // Test choosing free feature in root + { + tree::TrainParam param = GetParameter(); + int32_t constexpr kFeatures = 6; + std::string const constraints_str = R"constraint([[0, 1]])constraint"; + param.interaction_constraints = constraints_str; + FConstraintWrapper constraints(param, kFeatures); + std::vector h_input_feature_list {0, 1, 2, 3, 4, 5}; + dh::device_vector d_input_feature_list (h_input_feature_list); + common::Span s_input_feature_list = dh::ToSpan(d_input_feature_list); + constraints.Split(/*node_id=*/0, /*feature_id=*/2, /*left_id=*/1, /*right_id=*/2); + auto s_output = constraints.Query(s_input_feature_list, 1); + CompareFeatureList(s_output, {2}); + s_output = constraints.Query(s_input_feature_list, 2); + CompareFeatureList(s_output, {2}); + } +} + +} // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 614f2ff94..0031ac6d2 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -1,7 +1,6 @@ /*! - * Copyright 2017-2018 XGBoost contributors + * Copyright 2017-2019 XGBoost contributors */ - #include #include #include @@ -16,6 +15,7 @@ #include "../../../src/tree/updater_gpu_hist.cu" #include "../../../src/tree/updater_gpu_common.cuh" #include "../../../src/common/common.h" +#include "../../../src/tree/constraints.cuh" namespace xgboost { namespace tree { @@ -91,11 +91,13 @@ void BuildGidx(DeviceShard* shard, int n_rows, int n_cols, TEST(GpuHist, BuildGidxDense) { int constexpr kNRows = 16, kNCols = 8; - TrainParam param; - param.max_depth = 1; - param.max_leaves = 0; - - DeviceShard shard(0, 0, 0, kNRows, param, kNCols); + tree::TrainParam param; + std::vector> args { + {"max_depth", "1"}, + {"max_leaves", "0"}, + }; + param.Init(args); + DeviceShard shard(0, 0, 0, kNRows, param, kNCols, kNCols); BuildGidx(&shard, kNRows, kNCols); std::vector h_gidx_buffer(shard.gidx_buffer.size()); @@ -130,10 +132,14 @@ TEST(GpuHist, BuildGidxDense) { TEST(GpuHist, BuildGidxSparse) { int constexpr kNRows = 16, kNCols = 8; TrainParam param; - param.max_depth = 1; - param.max_leaves = 0; + std::vector> args { + {"max_depth", "1"}, + {"max_leaves", "0"}, + }; + param.Init(args); - DeviceShard shard(0, 0, 0, kNRows, param, kNCols); + DeviceShard shard(0, 0, 0, kNRows, param, kNCols, + kNCols); BuildGidx(&shard, kNRows, kNCols, 0.9f); std::vector h_gidx_buffer(shard.gidx_buffer.size()); @@ -173,10 +179,13 @@ void TestBuildHist(bool use_shared_memory_histograms) { int const kNRows = 16, kNCols = 8; TrainParam param; - param.max_depth = 6; - param.max_leaves = 0; - - DeviceShard shard(0, 0, 0, kNRows, param, kNCols); + std::vector> args { + {"max_depth", "6"}, + {"max_leaves", "0"}, + }; + param.Init(args); + DeviceShard shard(0, 0, 0, kNRows, param, kNCols, + kNCols); BuildGidx(&shard, kNRows, kNCols); xgboost::SimpleLCG gen; @@ -263,17 +272,21 @@ TEST(GpuHist, EvaluateSplits) { constexpr int kNCols = 8; TrainParam param; - param.max_depth = 1; - param.colsample_bynode = 1; - param.colsample_bylevel = 1; - param.colsample_bytree = 1; - param.min_child_weight = 0.01; - // Disable all parameters. - param.reg_alpha = 0.0; - param.reg_lambda = 0; - param.max_delta_step = 0.0; + std::vector> args { + {"max_depth", "1"}, + {"max_leaves", "0"}, + // Disable all other parameters. + {"colsample_bynode", "1"}, + {"colsample_bylevel", "1"}, + {"colsample_bytree", "1"}, + {"min_child_weight", "0.01"}, + {"reg_alpha", "0"}, + {"reg_lambda", "0"}, + {"max_delta_step", "0"} + }; + param.Init(args); for (size_t i = 0; i < kNCols; ++i) { param.monotone_constraints.emplace_back(0); } @@ -282,7 +295,8 @@ TEST(GpuHist, EvaluateSplits) { // Initialize DeviceShard std::unique_ptr> shard{ - new DeviceShard(0, 0, 0, kNRows, param, kNCols)}; + new DeviceShard(0, 0, 0, kNRows, param, kNCols, + kNCols)}; // Initialize DeviceShard::node_sum_gradients shard->node_sum_gradients = {{6.4f, 12.8f}}; @@ -352,14 +366,13 @@ TEST(GpuHist, ApplySplit) { TrainParam param; std::vector> args = {}; param.InitAllowUnknown(args); - // Initialize shard for (size_t i = 0; i < kNCols; ++i) { param.monotone_constraints.emplace_back(0); } - std::unique_ptr> shard{ - new DeviceShard(0, 0, 0, kNRows, param, kNCols)}; + new DeviceShard(0, 0, 0, kNRows, param, kNCols, + kNCols)}; shard->ridx_segments.resize(3); // 3 nodes. shard->node_sum_gradients.resize(3); diff --git a/tests/cpp/tree/test_split_evaluator.cc b/tests/cpp/tree/test_split_evaluator.cc new file mode 100644 index 000000000..e249b0d96 --- /dev/null +++ b/tests/cpp/tree/test_split_evaluator.cc @@ -0,0 +1,57 @@ +#include +#include +#include +#include "../../../src/tree/split_evaluator.h" + +namespace xgboost { +namespace tree { + +TEST(SplitEvaluator, Interaction) { + std::string constraints_str = R"interaction([[0, 1], [1, 2, 3]])interaction"; + std::vector> args{ + {"interaction_constraints", constraints_str}, + {"num_feature", "8"}}; + { + std::unique_ptr eval{ + SplitEvaluator::Create("elastic_net,interaction")}; + eval->Init(args); + + eval->AddSplit(0, 1, 2, /*feature_id=*/4, 0, 0); + eval->AddSplit(2, 3, 4, /*feature_id=*/5, 0, 0); + ASSERT_FALSE(eval->CheckFeatureConstraint(2, /*feature_id=*/0)); + ASSERT_FALSE(eval->CheckFeatureConstraint(2, /*feature_id=*/1)); + + ASSERT_TRUE(eval->CheckFeatureConstraint(2, /*feature_id=*/4)); + ASSERT_FALSE(eval->CheckFeatureConstraint(2, /*feature_id=*/5)); + + std::vector accepted_features; // for node 3 + for (int32_t f = 0; f < 8; ++f) { + if (eval->CheckFeatureConstraint(3, f)) { + accepted_features.emplace_back(f); + } + } + std::vector solutions{4, 5}; + ASSERT_EQ(accepted_features.size(), solutions.size()); + for (int32_t f = 0; f < accepted_features.size(); ++f) { + ASSERT_EQ(accepted_features[f], solutions[f]); + } + } + + { + std::unique_ptr eval{ + SplitEvaluator::Create("elastic_net,interaction")}; + eval->Init(args); + eval->AddSplit(/*node_id=*/0, /*left_id=*/1, /*right_id=*/2, /*feature_id=*/4, 0, 0); + std::vector accepted_features; // for node 1 + for (int32_t f = 0; f < 8; ++f) { + if (eval->CheckFeatureConstraint(1, f)) { + accepted_features.emplace_back(f); + } + } + ASSERT_EQ(accepted_features.size(), 1); + ASSERT_EQ(accepted_features[0], 4); + } +} + +} // namespace tree +} // namespace xgboost diff --git a/tests/python-gpu/test_gpu_interaction_constraints.py b/tests/python-gpu/test_gpu_interaction_constraints.py new file mode 100644 index 000000000..d0026dad1 --- /dev/null +++ b/tests/python-gpu/test_gpu_interaction_constraints.py @@ -0,0 +1,17 @@ +import numpy as np +import unittest +import sys +sys.path.append("tests/python") +# Don't import the test class, otherwise they will run twice. +import test_interaction_constraints as test_ic +rng = np.random.RandomState(1994) + + +class TestGPUInteractionConstraints(unittest.TestCase): + cputest = test_ic.TestInteractionConstraints() + + def test_interaction_constraints(self): + self.cputest.test_interaction_constraints(tree_method='gpu_hist') + + def test_training_accuracy(self): + self.cputest.test_training_accuracy(tree_method='gpu_hist') diff --git a/tests/python/test_interaction_constraints.py b/tests/python/test_interaction_constraints.py index 7cd50bf8b..8af483a76 100644 --- a/tests/python/test_interaction_constraints.py +++ b/tests/python/test_interaction_constraints.py @@ -9,27 +9,36 @@ rng = np.random.RandomState(1994) class TestInteractionConstraints(unittest.TestCase): - - def test_interaction_constraints(self): + def test_interaction_constraints(self, tree_method='hist'): x1 = np.random.normal(loc=1.0, scale=1.0, size=1000) x2 = np.random.normal(loc=1.0, scale=1.0, size=1000) x3 = np.random.choice([1, 2, 3], size=1000, replace=True) y = x1 + x2 + x3 + x1 * x2 * x3 \ - + np.random.normal(loc=0.001, scale=1.0, size=1000) + 3 * np.sin(x1) + + np.random.normal( + loc=0.001, scale=1.0, size=1000) + 3 * np.sin(x1) X = np.column_stack((x1, x2, x3)) dtrain = xgboost.DMatrix(X, label=y) - params = {'max_depth': 3, 'eta': 0.1, 'nthread': 2, 'verbosity': 0, - 'interaction_constraints': '[[0, 1]]', 'tree_method': 'hist'} - num_boost_round = 100 + params = { + 'max_depth': 3, + 'eta': 0.1, + 'nthread': 2, + 'interaction_constraints': '[[0, 1]]', + 'tree_method': tree_method, + 'verbosity': 2 + } + num_boost_round = 12 # Fit a model that only allows interaction between x1 and x2 - bst = xgboost.train(params, dtrain, num_boost_round, evals=[(dtrain, 'train')]) + bst = xgboost.train( + params, dtrain, num_boost_round, evals=[(dtrain, 'train')]) # Set all observations to have the same x3 values then increment # by the same amount def f(x): - tmat = xgboost.DMatrix(np.column_stack((x1, x2, np.repeat(x, 1000)))) + tmat = xgboost.DMatrix( + np.column_stack((x1, x2, np.repeat(x, 1000)))) return bst.predict(tmat) + preds = [f(x) for x in [1, 2, 3]] # Check incrementing x3 has the same effect on all observations @@ -40,11 +49,16 @@ class TestInteractionConstraints(unittest.TestCase): diff2 = preds[2] - preds[1] assert np.all(np.abs(diff2 - diff2[0]) < 1e-4) - def test_training_accuracy(self): + def test_training_accuracy(self, tree_method='hist'): dtrain = xgboost.DMatrix(dpath + 'agaricus.txt.train?indexing_mode=1') dtest = xgboost.DMatrix(dpath + 'agaricus.txt.test?indexing_mode=1') - params = {'eta': 1, 'max_depth': 6, 'objective': 'binary:logistic', - 'tree_method': 'hist', 'interaction_constraints': '[[1,2],[2,3,4]]'} + params = { + 'eta': 1, + 'max_depth': 6, + 'objective': 'binary:logistic', + 'tree_method': tree_method, + 'interaction_constraints': '[[1,2], [2,3,4]]' + } num_boost_round = 5 params['grow_policy'] = 'lossguide'