Refactor parts of fast histogram utilities (#3564)

* Refactor parts of fast histogram utilities

* Removed byte packing from column matrix
This commit is contained in:
Rory Mitchell
2018-08-09 17:59:57 +12:00
committed by GitHub
parent 3c72654e3b
commit bbb771f32e
8 changed files with 184 additions and 288 deletions

View File

@@ -18,11 +18,8 @@ TEST(gpu_hist_experimental, TestSparseShard) {
int columns = 80;
int max_bins = 4;
auto dmat = CreateDMatrix(rows, columns, 0.9f);
common::HistCutMatrix hmat;
common::GHistIndexMatrix gmat;
hmat.Init(dmat.get(), max_bins);
gmat.cut = &hmat;
gmat.Init(dmat.get());
gmat.Init(dmat.get(),max_bins);
TrainParam p;
p.max_depth = 6;
@@ -32,7 +29,7 @@ TEST(gpu_hist_experimental, TestSparseShard) {
const SparsePage& batch = iter->Value();
DeviceShard shard(0, 0, 0, rows, p);
shard.InitRowPtrs(batch);
shard.InitCompressedData(hmat, batch);
shard.InitCompressedData(gmat.cut, batch);
CHECK(!iter->Next());
ASSERT_LT(shard.row_stride, columns);
@@ -40,7 +37,7 @@ TEST(gpu_hist_experimental, TestSparseShard) {
auto host_gidx_buffer = shard.gidx_buffer.AsVector();
common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
hmat.row_ptr.back() + 1);
gmat.cut.row_ptr.back() + 1);
for (int i = 0; i < rows; i++) {
int row_offset = 0;
@@ -60,11 +57,8 @@ TEST(gpu_hist_experimental, TestDenseShard) {
int columns = 80;
int max_bins = 4;
auto dmat = CreateDMatrix(rows, columns, 0);
common::HistCutMatrix hmat;
common::GHistIndexMatrix gmat;
hmat.Init(dmat.get(), max_bins);
gmat.cut = &hmat;
gmat.Init(dmat.get());
gmat.Init(dmat.get(),max_bins);
TrainParam p;
p.max_depth = 6;
@@ -75,7 +69,7 @@ TEST(gpu_hist_experimental, TestDenseShard) {
DeviceShard shard(0, 0, 0, rows, p);
shard.InitRowPtrs(batch);
shard.InitCompressedData(hmat, batch);
shard.InitCompressedData(gmat.cut, batch);
CHECK(!iter->Next());
ASSERT_EQ(shard.row_stride, columns);
@@ -83,7 +77,7 @@ TEST(gpu_hist_experimental, TestDenseShard) {
auto host_gidx_buffer = shard.gidx_buffer.AsVector();
common::CompressedIterator<uint32_t> gidx(host_gidx_buffer.data(),
hmat.row_ptr.back() + 1);
gmat.cut.row_ptr.back() + 1);
for (int i = 0; i < gmat.index.size(); i++) {
ASSERT_EQ(gidx[i], gmat.index[i]);