Extract interaction constraint from split evaluator. (#5034)
* Extract interaction constraints from split evaluator. The reason for doing so is mostly for model IO, where num_feature and interaction_constraints are copied in split evaluator. Also interaction constraint by itself is a feature selector, acting like column sampler and it's inefficient to bury it deep in the evaluator chain. Lastly removing one another copied parameter is a win. * Enable inc for approx tree method. As now the implementation is spited up from evaluator class, it's also enabled for approx method. * Removing obsoleted code in colmaker. They are never documented nor actually used in real world. Also there isn't a single test for those code blocks. * Unifying the types used for row and column. As the size of input dataset is marching to billion, incorrect use of int is subject to overflow, also singed integer overflow is undefined behaviour. This PR starts the procedure for unifying used index type to unsigned integers. There's optimization that can utilize this undefined behaviour, but after some testings I don't see the optimization is beneficial to XGBoost.
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
* Copyright 2017-2019 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include "xgboost/base.h"
|
||||
#include "../../common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -30,7 +31,6 @@ __forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment
|
||||
* partition training rows into different leaf nodes. */
|
||||
class RowPartitioner {
|
||||
public:
|
||||
using TreePositionT = int32_t;
|
||||
using RowIndexT = bst_uint;
|
||||
struct Segment;
|
||||
|
||||
@@ -47,8 +47,8 @@ class RowPartitioner {
|
||||
std::vector<Segment> ridx_segments;
|
||||
dh::caching_device_vector<RowIndexT> ridx_a;
|
||||
dh::caching_device_vector<RowIndexT> ridx_b;
|
||||
dh::caching_device_vector<TreePositionT> position_a;
|
||||
dh::caching_device_vector<TreePositionT> position_b;
|
||||
dh::caching_device_vector<bst_node_t> position_a;
|
||||
dh::caching_device_vector<bst_node_t> position_b;
|
||||
/*! \brief mapping for node id -> rows.
|
||||
* This looks like:
|
||||
* node id | 1 | 2 |
|
||||
@@ -56,7 +56,7 @@ class RowPartitioner {
|
||||
*/
|
||||
dh::DoubleBuffer<RowIndexT> ridx;
|
||||
/*! \brief mapping for row -> node id. */
|
||||
dh::DoubleBuffer<TreePositionT> position;
|
||||
dh::DoubleBuffer<bst_node_t> position;
|
||||
dh::caching_device_vector<int64_t>
|
||||
left_counts; // Useful to keep a bunch of zeroed memory for sort position
|
||||
std::vector<cudaStream_t> streams;
|
||||
@@ -70,7 +70,7 @@ class RowPartitioner {
|
||||
/**
|
||||
* \brief Gets the row indices of training instances in a given node.
|
||||
*/
|
||||
common::Span<const RowIndexT> GetRows(TreePositionT nidx);
|
||||
common::Span<const RowIndexT> GetRows(bst_node_t nidx);
|
||||
|
||||
/**
|
||||
* \brief Gets all training rows in the set.
|
||||
@@ -80,17 +80,17 @@ class RowPartitioner {
|
||||
/**
|
||||
* \brief Gets the tree position of all training instances.
|
||||
*/
|
||||
common::Span<const TreePositionT> GetPosition();
|
||||
common::Span<const bst_node_t> GetPosition();
|
||||
|
||||
/**
|
||||
* \brief Convenience method for testing
|
||||
*/
|
||||
std::vector<RowIndexT> GetRowsHost(TreePositionT nidx);
|
||||
std::vector<RowIndexT> GetRowsHost(bst_node_t nidx);
|
||||
|
||||
/**
|
||||
* \brief Convenience method for testing
|
||||
*/
|
||||
std::vector<TreePositionT> GetPositionHost();
|
||||
std::vector<bst_node_t> GetPositionHost();
|
||||
|
||||
/**
|
||||
* \brief Updates the tree position for set of training instances being split
|
||||
@@ -105,8 +105,8 @@ class RowPartitioner {
|
||||
* argument and return the new position for this training instance.
|
||||
*/
|
||||
template <typename UpdatePositionOpT>
|
||||
void UpdatePosition(TreePositionT nidx, TreePositionT left_nidx,
|
||||
TreePositionT right_nidx, UpdatePositionOpT op) {
|
||||
void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx,
|
||||
bst_node_t right_nidx, UpdatePositionOpT op) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
Segment segment = ridx_segments.at(nidx); // rows belongs to node nidx
|
||||
auto d_ridx = ridx.CurrentSpan();
|
||||
@@ -123,7 +123,7 @@ class RowPartitioner {
|
||||
// LaunchN starts from zero, so we restore the row index by adding segment.begin
|
||||
idx += segment.begin;
|
||||
RowIndexT ridx = d_ridx[idx];
|
||||
TreePositionT new_position = op(ridx); // new node id
|
||||
bst_node_t new_position = op(ridx); // new node id
|
||||
KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
|
||||
AtomicIncrement(d_left_count, new_position == left_nidx);
|
||||
d_position[idx] = new_position;
|
||||
@@ -172,16 +172,16 @@ class RowPartitioner {
|
||||
* segments. Based on a single pass of exclusive scan, uses iterators to
|
||||
* redirect inputs and outputs.
|
||||
*/
|
||||
void SortPosition(common::Span<TreePositionT> position,
|
||||
common::Span<TreePositionT> position_out,
|
||||
void SortPosition(common::Span<bst_node_t> position,
|
||||
common::Span<bst_node_t> position_out,
|
||||
common::Span<RowIndexT> ridx,
|
||||
common::Span<RowIndexT> ridx_out, TreePositionT left_nidx,
|
||||
TreePositionT right_nidx, int64_t* d_left_count,
|
||||
common::Span<RowIndexT> ridx_out, bst_node_t left_nidx,
|
||||
bst_node_t right_nidx, int64_t* d_left_count,
|
||||
cudaStream_t stream = nullptr);
|
||||
|
||||
/*! \brief Sort row indices according to position. */
|
||||
void SortPositionAndCopy(const Segment& segment, TreePositionT left_nidx,
|
||||
TreePositionT right_nidx, int64_t* d_left_count,
|
||||
void SortPositionAndCopy(const Segment& segment, bst_node_t left_nidx,
|
||||
bst_node_t right_nidx, int64_t* d_left_count,
|
||||
cudaStream_t stream);
|
||||
/** \brief Used to demarcate a contiguous set of row indices associated with
|
||||
* some tree node. */
|
||||
|
||||
Reference in New Issue
Block a user