Further optimisations for gpu_hist. (#4283)

- Fuse final update position functions into a single, more efficient kernel

- Refactor gpu_hist with a more explicit ellpack matrix representation
This commit is contained in:
Rory Mitchell
2019-03-24 17:17:22 +13:00
committed by GitHub
parent 5aa42b5f11
commit 6d5b34d824
5 changed files with 345 additions and 297 deletions

View File

@@ -39,8 +39,9 @@ void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
0.26f, 0.74f, 1.98f,
0.26f, 0.71f, 1.83f};
shard->InitRowPtrs(batch);
shard->InitCompressedData(cmat, batch);
auto is_dense = (*dmat)->Info().num_nonzero_ ==
(*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
shard->InitCompressedData(cmat, batch, is_dense);
delete dmat;
}
@@ -59,7 +60,7 @@ TEST(GpuHist, BuildGidxDense) {
h_gidx_buffer = shard.gidx_buffer.AsVector();
common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
ASSERT_EQ(shard.row_stride, kNCols);
ASSERT_EQ(shard.ellpack_matrix.row_stride, kNCols);
std::vector<uint32_t> solution = {
0, 3, 8, 9, 14, 17, 20, 21,
@@ -98,7 +99,7 @@ TEST(GpuHist, BuildGidxSparse) {
h_gidx_buffer = shard.gidx_buffer.AsVector();
common::CompressedIterator<uint32_t> gidx(h_gidx_buffer.data(), 25);
ASSERT_LE(shard.row_stride, 3);
ASSERT_LE(shard.ellpack_matrix.row_stride, 3);
// row_stride = 3, 16 rows, 48 entries for ELLPack
std::vector<uint32_t> solution = {
@@ -106,7 +107,7 @@ TEST(GpuHist, BuildGidxSparse) {
24, 24, 24, 24, 24, 5, 24, 24, 0, 16, 24, 15, 24, 24, 24, 24,
24, 7, 14, 16, 4, 24, 24, 24, 24, 24, 9, 24, 24, 1, 24, 24
};
for (size_t i = 0; i < kNRows * shard.row_stride; ++i) {
for (size_t i = 0; i < kNRows * shard.ellpack_matrix.row_stride; ++i) {
ASSERT_EQ(solution[i], gidx[i]);
}
}
@@ -256,16 +257,19 @@ TEST(GpuHist, EvaluateSplits) {
common::HistCutMatrix cmat = GetHostCutMatrix();
// Copy cut matrix to device.
DeviceShard<GradientPairPrecise>::DeviceHistCutMatrix cut;
shard->ba.Allocate(0,
&(shard->d_cut.feature_segments), cmat.row_ptr.size(),
&(shard->d_cut.min_fvalue), cmat.min_val.size(),
&(shard->d_cut.gidx_fvalue_map), 24,
&(shard->feature_segments), cmat.row_ptr.size(),
&(shard->min_fvalue), cmat.min_val.size(),
&(shard->gidx_fvalue_map), 24,
&(shard->monotone_constraints), kNCols);
shard->d_cut.feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end());
shard->d_cut.gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end());
shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end());
shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end());
shard->monotone_constraints.copy(param.monotone_constraints.begin(),
param.monotone_constraints.end());
shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan();
shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan();
shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end());
shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan();
// Initialize DeviceShard::hist
shard->hist.Init(0, (max_bins - 1) * kNCols);
@@ -339,7 +343,7 @@ TEST(GpuHist, ApplySplit) {
shard->ridx_segments[0] = Segment(0, kNRows);
shard->ba.Allocate(0, &(shard->ridx), kNRows,
&(shard->position), kNRows);
shard->row_stride = kNCols;
shard->ellpack_matrix.row_stride = kNCols;
thrust::sequence(shard->ridx.CurrentDVec().tbegin(),
shard->ridx.CurrentDVec().tend());
// Initialize GPUHistMaker
@@ -351,11 +355,9 @@ TEST(GpuHist, ApplySplit) {
0.59, 4, // fvalue has to be equal to one of the cut field
GradientPair(8.2, 2.8), GradientPair(6.3, 3.6),
GPUTrainingParam(param));
GPUHistMakerSpecialised<GradientPairPrecise>::ExpandEntry candidate_entry {0, 0, candidate, 0};
ExpandEntry candidate_entry {0, 0, candidate, 0};
candidate_entry.nid = kNId;
auto const& nodes = tree.GetNodes();
// Used to get bin_id in update position.
common::HistCutMatrix cmat = GetHostCutMatrix();
hist_maker.hmat_ = cmat;
@@ -370,19 +372,31 @@ TEST(GpuHist, ApplySplit) {
int row_stride = kNCols;
int num_symbols = n_bins + 1;
size_t compressed_size_bytes =
common::CompressedBufferWriter::CalculateBufferSize(
row_stride * kNRows, num_symbols);
shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes);
common::CompressedBufferWriter::CalculateBufferSize(row_stride * kNRows,
num_symbols);
shard->ba.Allocate(0, &(shard->gidx_buffer), compressed_size_bytes,
&(shard->feature_segments), cmat.row_ptr.size(),
&(shard->min_fvalue), cmat.min_val.size(),
&(shard->gidx_fvalue_map), 24);
shard->feature_segments.copy(cmat.row_ptr.begin(), cmat.row_ptr.end());
shard->gidx_fvalue_map.copy(cmat.cut.begin(), cmat.cut.end());
shard->ellpack_matrix.feature_segments = shard->feature_segments.GetSpan();
shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map.GetSpan();
shard->min_fvalue.copy(cmat.min_val.begin(), cmat.min_val.end());
shard->ellpack_matrix.min_fvalue = shard->min_fvalue.GetSpan();
shard->ellpack_matrix.is_dense = true;
common::CompressedBufferWriter wr(num_symbols);
std::vector<int> h_gidx (kNRows * row_stride);
std::iota(h_gidx.begin(), h_gidx.end(), 0);
// gidx 14 should go right, 12 goes left
std::vector<int> h_gidx (kNRows * row_stride, 14);
h_gidx[4] = 12;
h_gidx[12] = 12;
std::vector<common::CompressedByteT> h_gidx_compressed (compressed_size_bytes);
wr.Write(h_gidx_compressed.data(), h_gidx.begin(), h_gidx.end());
shard->gidx_buffer.copy(h_gidx_compressed.begin(), h_gidx_compressed.end());
shard->gidx = common::CompressedIterator<uint32_t>(
shard->ellpack_matrix.gidx_iter = common::CompressedIterator<uint32_t>(
shard->gidx_buffer.Data(), num_symbols);
hist_maker.info_ = &info;
@@ -395,8 +409,8 @@ TEST(GpuHist, ApplySplit) {
int right_nidx = tree[kNId].RightChild();
ASSERT_EQ(shard->ridx_segments[left_nidx].begin, 0);
ASSERT_EQ(shard->ridx_segments[left_nidx].end, 6);
ASSERT_EQ(shard->ridx_segments[right_nidx].begin, 6);
ASSERT_EQ(shard->ridx_segments[left_nidx].end, 2);
ASSERT_EQ(shard->ridx_segments[right_nidx].begin, 2);
ASSERT_EQ(shard->ridx_segments[right_nidx].end, 16);
}
@@ -417,7 +431,7 @@ void TestSortPosition(const std::vector<int>& position_in, int left_idx,
common::Span<int>(position_out.data().get(), position_out.size()),
common::Span<bst_uint>(ridx.data().get(), ridx.size()),
common::Span<bst_uint>(ridx_out.data().get(), ridx_out.size()), left_idx,
right_idx, d_left_count.data().get());
right_idx, d_left_count.data().get(), nullptr);
thrust::host_vector<int> position_result = position_out;
thrust::host_vector<int> ridx_result = ridx_out;