Fix CPU hist init for sparse dataset. (#4625)

* Fix CPU hist init for sparse dataset.

* Implement sparse histogram cut.
* Allow empty features.

* Fix windows build, don't use sparse in distributed environment.

* Comments.

* Smaller threshold.

* Fix windows omp.

* Fix msvc lambda capture.

* Fix MSVC macro.

* Fix MSVC initialization list.

* Fix MSVC initialization list x2.

* Preserve categorical feature behavior.

* Rename matrix to sparse cuts.
* Reuse UseGroup.
* Check for categorical data when adding cut.

Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>

* Sanity check.

* Fix comments.

* Fix comment.
This commit is contained in:
Jiaming Yuan
2019-07-04 19:27:03 -04:00
committed by Philip Hyunsu Cho
parent b7a1f22d24
commit d9a47794a5
33 changed files with 681 additions and 299 deletions

View File

@@ -53,10 +53,10 @@ TEST(c_api, XGDMatrixCreateFromMat_omp) {
ASSERT_EQ(info.num_nonzero_, num_cols * row - num_missing);
for (const auto &batch : (*dmat)->GetRowBatches()) {
for (int i = 0; i < batch.Size(); i++) {
for (size_t i = 0; i < batch.Size(); i++) {
auto inst = batch[i];
for (int j = 0; i < inst.size(); i++) {
ASSERT_EQ(inst[j].fvalue, 1.5);
for (auto e : inst) {
ASSERT_EQ(e.fvalue, 1.5);
}
}
}

View File

@@ -7,6 +7,7 @@
namespace xgboost {
namespace common {
TEST(DenseColumn, Test) {
auto dmat = CreateDMatrix(100, 10, 0.0);
GHistIndexMatrix gmat;
@@ -17,7 +18,7 @@ TEST(DenseColumn, Test) {
for (auto i = 0ull; i < (*dmat)->Info().num_row_; i++) {
for (auto j = 0ull; j < (*dmat)->Info().num_col_; j++) {
auto col = column_matrix.GetColumn(j);
EXPECT_EQ(gmat.index[i * (*dmat)->Info().num_col_ + j],
ASSERT_EQ(gmat.index[i * (*dmat)->Info().num_col_ + j],
col.GetGlobalBinIdx(i));
}
}
@@ -33,7 +34,7 @@ TEST(SparseColumn, Test) {
auto col = column_matrix.GetColumn(0);
ASSERT_EQ(col.Size(), gmat.index.size());
for (auto i = 0ull; i < col.Size(); i++) {
EXPECT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
col.GetGlobalBinIdx(i));
}
delete dmat;

View File

@@ -28,7 +28,7 @@ TEST(CompressedIterator, Test) {
CompressedIterator<int> ci(buffer.data(), alphabet_size);
std::vector<int> output(input.size());
for (int i = 0; i < input.size(); i++) {
for (size_t i = 0; i < input.size(); i++) {
output[i] = ci[i];
}
@@ -38,12 +38,12 @@ TEST(CompressedIterator, Test) {
std::vector<unsigned char> buffer2(
CompressedBufferWriter::CalculateBufferSize(input.size(),
alphabet_size));
for (int i = 0; i < input.size(); i++) {
for (size_t i = 0; i < input.size(); i++) {
cbw.WriteSymbol(buffer2.data(), input[i], i);
}
CompressedIterator<int> ci2(buffer.data(), alphabet_size);
std::vector<int> output2(input.size());
for (int i = 0; i < input.size(); i++) {
for (size_t i = 0; i < input.size(); i++) {
output2[i] = ci2[i];
}
ASSERT_TRUE(input == output2);

View File

@@ -48,11 +48,11 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
int gpu_batch_nrows = 0;
// find quantiles on the CPU
HistCutMatrix hmat_cpu;
hmat_cpu.Init((*dmat).get(), p.max_bin);
HistogramCuts hmat_cpu;
hmat_cpu.Build((*dmat).get(), p.max_bin);
// find the cuts on the GPU
HistCutMatrix hmat_gpu;
HistogramCuts hmat_gpu;
size_t row_stride = DeviceSketch(p, CreateEmptyGenericParam(0, devices.Size()), gpu_batch_nrows,
dmat->get(), &hmat_gpu);
@@ -69,12 +69,12 @@ void TestDeviceSketch(const GPUSet& devices, bool use_external_memory) {
// compare the cuts
double eps = 1e-2;
ASSERT_EQ(hmat_gpu.min_val.size(), num_cols);
ASSERT_EQ(hmat_gpu.row_ptr.size(), num_cols + 1);
ASSERT_EQ(hmat_gpu.cut.size(), hmat_cpu.cut.size());
ASSERT_LT(fabs(hmat_cpu.min_val[0] - hmat_gpu.min_val[0]), eps * nrows);
for (int i = 0; i < hmat_gpu.cut.size(); ++i) {
ASSERT_LT(fabs(hmat_cpu.cut[i] - hmat_gpu.cut[i]), eps * nrows);
ASSERT_EQ(hmat_gpu.MinValues().size(), num_cols);
ASSERT_EQ(hmat_gpu.Ptrs().size(), num_cols + 1);
ASSERT_EQ(hmat_gpu.Values().size(), hmat_cpu.Values().size());
ASSERT_LT(fabs(hmat_cpu.MinValues()[0] - hmat_gpu.MinValues()[0]), eps * nrows);
for (int i = 0; i < hmat_gpu.Values().size(); ++i) {
ASSERT_LT(fabs(hmat_cpu.Values()[i] - hmat_gpu.Values()[i]), eps * nrows);
}
delete dmat;

View File

@@ -9,15 +9,7 @@
namespace xgboost {
namespace common {
class HistCutMatrixMock : public HistCutMatrix {
public:
size_t SearchGroupIndFromBaseRow(
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
return HistCutMatrix::SearchGroupIndFromBaseRow(group_ptr, base_rowid);
}
};
TEST(HistCutMatrix, SearchGroupInd) {
TEST(CutsBuilder, SearchGroupInd) {
size_t constexpr kNumGroups = 4;
size_t constexpr kNumRows = 17;
size_t constexpr kNumCols = 15;
@@ -34,18 +26,102 @@ TEST(HistCutMatrix, SearchGroupInd) {
p_mat->Info().SetInfo(
"group", group.data(), DataType::kUInt32, kNumGroups);
HistCutMatrixMock hmat;
HistogramCuts hmat;
size_t group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 0);
size_t group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0);
ASSERT_EQ(group_ind, 0);
group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 5);
group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
ASSERT_EQ(group_ind, 2);
EXPECT_ANY_THROW(hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 17));
EXPECT_ANY_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));
delete pp_mat;
}
namespace {
class SparseCutsWrapper : public SparseCuts {
public:
std::vector<uint32_t> const& ColPtrs() const { return p_cuts_->Ptrs(); }
std::vector<float> const& ColValues() const { return p_cuts_->Values(); }
};
} // anonymous namespace
TEST(SparseCuts, SingleThreadedBuild) {
size_t constexpr kRows = 267;
size_t constexpr kCols = 31;
size_t constexpr kBins = 256;
// Dense matrix.
auto pp_mat = CreateDMatrix(kRows, kCols, 0);
DMatrix* p_fmat = (*pp_mat).get();
common::GHistIndexMatrix hmat;
hmat.Init(p_fmat, kBins);
HistogramCuts cuts;
SparseCuts indices(&cuts);
auto const& page = *(p_fmat->GetColumnBatches().begin());
indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
ASSERT_EQ(hmat.cut.Ptrs().size(), cuts.Ptrs().size());
ASSERT_EQ(hmat.cut.Ptrs(), cuts.Ptrs());
ASSERT_EQ(hmat.cut.Values(), cuts.Values());
ASSERT_EQ(hmat.cut.MinValues(), cuts.MinValues());
delete pp_mat;
}
TEST(SparseCuts, MultiThreadedBuild) {
size_t constexpr kRows = 17;
size_t constexpr kCols = 15;
size_t constexpr kBins = 255;
omp_ulong ori_nthreads = omp_get_max_threads();
omp_set_num_threads(16);
auto Compare =
#if defined(_MSC_VER) // msvc fails to capture
[kBins](DMatrix* p_fmat) {
#else
[](DMatrix* p_fmat) {
#endif
HistogramCuts threaded_container;
SparseCuts threaded_indices(&threaded_container);
threaded_indices.Build(p_fmat, kBins);
HistogramCuts container;
SparseCuts indices(&container);
auto const& page = *(p_fmat->GetColumnBatches().begin());
indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
ASSERT_EQ(container.Ptrs().size(), threaded_container.Ptrs().size());
ASSERT_EQ(container.Values().size(), threaded_container.Values().size());
for (uint32_t i = 0; i < container.Ptrs().size(); ++i) {
ASSERT_EQ(container.Ptrs()[i], threaded_container.Ptrs()[i]);
}
for (uint32_t i = 0; i < container.Values().size(); ++i) {
ASSERT_EQ(container.Values()[i], threaded_container.Values()[i]);
}
};
{
auto pp_mat = CreateDMatrix(kRows, kCols, 0);
DMatrix* p_fmat = (*pp_mat).get();
Compare(p_fmat);
delete pp_mat;
}
{
auto pp_mat = CreateDMatrix(kRows, kCols, 0.0001);
DMatrix* p_fmat = (*pp_mat).get();
Compare(p_fmat);
delete pp_mat;
}
omp_set_num_threads(ori_nthreads);
}
} // namespace common
} // namespace xgboost

View File

@@ -53,8 +53,8 @@ TEST(ColumnSampler, Test) {
TEST(ColumnSampler, ThreadSynchronisation) {
const int64_t num_threads = 100;
int n = 128;
int iterations = 10;
int levels = 5;
size_t iterations = 10;
size_t levels = 5;
std::vector<int> reference_result;
bool success =
true; // Cannot use google test asserts in multithreaded region

View File

@@ -310,7 +310,7 @@ TEST(Span, FirstLast) {
ASSERT_EQ(first.size(), 4);
ASSERT_EQ(first.data(), arr);
for (size_t i = 0; i < first.size(); ++i) {
for (int64_t i = 0; i < first.size(); ++i) {
ASSERT_EQ(first[i], arr[i]);
}
@@ -329,7 +329,7 @@ TEST(Span, FirstLast) {
ASSERT_EQ(last.size(), 4);
ASSERT_EQ(last.data(), arr + 12);
for (size_t i = 0; i < last.size(); ++i) {
for (int64_t i = 0; i < last.size(); ++i) {
ASSERT_EQ(last[i], arr[i+12]);
}
@@ -348,7 +348,7 @@ TEST(Span, FirstLast) {
ASSERT_EQ(first.size(), 4);
ASSERT_EQ(first.data(), s.data());
for (size_t i = 0; i < first.size(); ++i) {
for (int64_t i = 0; i < first.size(); ++i) {
ASSERT_EQ(first[i], s[i]);
}
@@ -368,7 +368,7 @@ TEST(Span, FirstLast) {
ASSERT_EQ(last.size(), 4);
ASSERT_EQ(last.data(), s.data() + 12);
for (size_t i = 0; i < last.size(); ++i) {
for (int64_t i = 0; i < last.size(); ++i) {
ASSERT_EQ(s[12 + i], last[i]);
}

View File

@@ -50,7 +50,7 @@ TEST(SparsePage, PushCSC) {
inst = page[1];
ASSERT_EQ(inst.size(), 6);
std::vector<size_t> indices_sol {1, 2, 3};
for (size_t i = 0; i < inst.size(); ++i) {
for (int64_t i = 0; i < inst.size(); ++i) {
ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
}
}

View File

@@ -21,13 +21,13 @@ TEST(cpu_predictor, Test) {
HostDeviceVector<float> out_predictions;
cpu_predictor->PredictBatch((*dmat).get(), &out_predictions, model, 0);
std::vector<float>& out_predictions_h = out_predictions.HostVector();
for (int i = 0; i < out_predictions.Size(); i++) {
for (size_t i = 0; i < out_predictions.Size(); i++) {
ASSERT_EQ(out_predictions_h[i], 1.5);
}
// Test predict instance
auto &batch = *(*dmat)->GetRowBatches().begin();
for (int i = 0; i < batch.Size(); i++) {
for (size_t i = 0; i < batch.Size(); i++) {
std::vector<float> instance_out_predictions;
cpu_predictor->PredictInstance(batch[i], &instance_out_predictions, model);
ASSERT_EQ(instance_out_predictions[0], 1.5);

View File

@@ -94,7 +94,7 @@ void TestUpdatePosition() {
}
TEST(RowPartitioner, Basic) { TestUpdatePosition(); }
void TestFinalise() {
const int kNumRows = 10;
RowPartitioner rp(0, kNumRows);

View File

@@ -53,27 +53,43 @@ TEST(GpuHist, DeviceHistogram) {
}
}
};
}
namespace {
class HistogramCutsWrapper : public common::HistogramCuts {
public:
using SuperT = common::HistogramCuts;
void SetValues(std::vector<float> cuts) {
SuperT::cut_values_ = cuts;
}
void SetPtrs(std::vector<uint32_t> ptrs) {
SuperT::cut_ptrs_ = ptrs;
}
void SetMins(std::vector<float> mins) {
SuperT::min_vals_ = mins;
}
};
} // anonymous namespace
template <typename GradientSumT>
void BuildGidx(DeviceShard<GradientSumT>* shard, int n_rows, int n_cols,
bst_float sparsity=0) {
auto dmat = CreateDMatrix(n_rows, n_cols, sparsity, 3);
const SparsePage& batch = *(*dmat)->GetRowBatches().begin();
common::HistCutMatrix cmat;
cmat.row_ptr = {0, 3, 6, 9, 12, 15, 18, 21, 24};
cmat.min_val = {0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f};
HistogramCutsWrapper cmat;
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
// 24 cut fields, 3 cut fields for each feature (column).
cmat.cut = {0.30f, 0.67f, 1.64f,
0.32f, 0.77f, 1.95f,
0.29f, 0.70f, 1.80f,
0.32f, 0.75f, 1.85f,
0.18f, 0.59f, 1.69f,
0.25f, 0.74f, 2.00f,
0.26f, 0.74f, 1.98f,
0.26f, 0.71f, 1.83f};
cmat.SetValues({0.30f, 0.67f, 1.64f,
0.32f, 0.77f, 1.95f,
0.29f, 0.70f, 1.80f,
0.32f, 0.75f, 1.85f,
0.18f, 0.59f, 1.69f,
0.25f, 0.74f, 2.00f,
0.26f, 0.74f, 1.98f,
0.26f, 0.71f, 1.83f});
cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
auto is_dense = (*dmat)->Info().num_nonzero_ ==
(*dmat)->Info().num_row_ * (*dmat)->Info().num_col_;
@@ -241,20 +257,20 @@ TEST(GpuHist, BuildHistSharedMem) {
TestBuildHist<GradientPair>(true);
}
common::HistCutMatrix GetHostCutMatrix () {
common::HistCutMatrix cmat;
cmat.row_ptr = {0, 3, 6, 9, 12, 15, 18, 21, 24};
cmat.min_val = {0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f};
HistogramCutsWrapper GetHostCutMatrix () {
HistogramCutsWrapper cmat;
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f});
// 24 cut fields, 3 cut fields for each feature (column).
// Each row of the cut represents the cuts for a data column.
cmat.cut = {0.30f, 0.67f, 1.64f,
cmat.SetValues({0.30f, 0.67f, 1.64f,
0.32f, 0.77f, 1.95f,
0.29f, 0.70f, 1.80f,
0.32f, 0.75f, 1.85f,
0.18f, 0.59f, 1.69f,
0.25f, 0.74f, 2.00f,
0.26f, 0.74f, 1.98f,
0.26f, 0.71f, 1.83f};
0.26f, 0.71f, 1.83f});
return cmat;
}
@@ -293,21 +309,21 @@ TEST(GpuHist, EvaluateSplits) {
shard->node_sum_gradients = {{6.4f, 12.8f}};
// Initialize DeviceShard::cut
common::HistCutMatrix cmat = GetHostCutMatrix();
auto cmat = GetHostCutMatrix();
// Copy cut matrix to device.
shard->ba.Allocate(0,
&(shard->feature_segments), cmat.row_ptr.size(),
&(shard->min_fvalue), cmat.min_val.size(),
&(shard->feature_segments), cmat.Ptrs().size(),
&(shard->min_fvalue), cmat.MinValues().size(),
&(shard->gidx_fvalue_map), 24,
&(shard->monotone_constraints), kNCols);
dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.row_ptr);
dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.cut);
dh::CopyVectorToDeviceSpan(shard->feature_segments, cmat.Ptrs());
dh::CopyVectorToDeviceSpan(shard->gidx_fvalue_map, cmat.Values());
dh::CopyVectorToDeviceSpan(shard->monotone_constraints,
param.monotone_constraints);
shard->ellpack_matrix.feature_segments = shard->feature_segments;
shard->ellpack_matrix.gidx_fvalue_map = shard->gidx_fvalue_map;
dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.min_val);
dh::CopyVectorToDeviceSpan(shard->min_fvalue, cmat.MinValues());
shard->ellpack_matrix.min_fvalue = shard->min_fvalue;
// Initialize DeviceShard::hist

View File

@@ -13,7 +13,7 @@ namespace xgboost {
namespace tree {
TEST(Updater, Prune) {
int constexpr kNRows = 32, kNCols = 16;
int constexpr kNCols = 16;
std::vector<std::pair<std::string, std::string>> cfg;
cfg.emplace_back(std::pair<std::string, std::string>(

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2018 by Contributors
* Copyright 2018-2019 by Contributors
*/
#include "../helpers.h"
#include "../../../src/tree/param.h"
@@ -46,23 +46,25 @@ class QuantileHistMock : public QuantileHistMaker {
const size_t num_row = p_fmat->Info().num_row_;
const size_t num_col = p_fmat->Info().num_col_;
/* Validate HistCutMatrix */
ASSERT_EQ(gmat.cut.row_ptr.size(), num_col + 1);
ASSERT_EQ(gmat.cut.Ptrs().size(), num_col + 1);
for (size_t fid = 0; fid < num_col; ++fid) {
// Each feature must have at least one quantile point (cut)
const size_t ibegin = gmat.cut.row_ptr[fid];
const size_t iend = gmat.cut.row_ptr[fid + 1];
ASSERT_LT(ibegin, iend);
const size_t ibegin = gmat.cut.Ptrs()[fid];
const size_t iend = gmat.cut.Ptrs()[fid + 1];
// Ordered, but empty feature is allowed.
ASSERT_LE(ibegin, iend);
for (size_t i = ibegin; i < iend - 1; ++i) {
// Quantile points must be sorted in ascending order
// No duplicates allowed
ASSERT_LT(gmat.cut.cut[i], gmat.cut.cut[i + 1]);
ASSERT_LT(gmat.cut.Values()[i], gmat.cut.Values()[i + 1])
<< "ibegin: " << ibegin << ", "
<< "iend: " << iend;
}
}
/* Validate GHistIndexMatrix */
ASSERT_EQ(gmat.row_ptr.size(), num_row + 1);
ASSERT_LT(*std::max_element(gmat.index.begin(), gmat.index.end()),
gmat.cut.row_ptr.back());
gmat.cut.Ptrs().back());
for (const auto& batch : p_fmat->GetRowBatches()) {
for (size_t i = 0; i < batch.Size(); ++i) {
const size_t rid = batch.base_rowid + i;
@@ -71,20 +73,20 @@ class QuantileHistMock : public QuantileHistMaker {
ASSERT_LT(gmat_row_offset, gmat.index.size());
SparsePage::Inst inst = batch[i];
ASSERT_EQ(gmat.row_ptr[rid] + inst.size(), gmat.row_ptr[rid + 1]);
for (size_t j = 0; j < inst.size(); ++j) {
for (int64_t j = 0; j < inst.size(); ++j) {
// Each entry of GHistIndexMatrix represents a bin ID
const size_t bin_id = gmat.index[gmat_row_offset + j];
const size_t fid = inst[j].index;
// The bin ID must correspond to correct feature
ASSERT_GE(bin_id, gmat.cut.row_ptr[fid]);
ASSERT_LT(bin_id, gmat.cut.row_ptr[fid + 1]);
ASSERT_GE(bin_id, gmat.cut.Ptrs()[fid]);
ASSERT_LT(bin_id, gmat.cut.Ptrs()[fid + 1]);
// The bin ID must correspond to a region between two
// suitable quantile points
ASSERT_LT(inst[j].fvalue, gmat.cut.cut[bin_id]);
if (bin_id > gmat.cut.row_ptr[fid]) {
ASSERT_GE(inst[j].fvalue, gmat.cut.cut[bin_id - 1]);
ASSERT_LT(inst[j].fvalue, gmat.cut.Values()[bin_id]);
if (bin_id > gmat.cut.Ptrs()[fid]) {
ASSERT_GE(inst[j].fvalue, gmat.cut.Values()[bin_id - 1]);
} else {
ASSERT_GE(inst[j].fvalue, gmat.cut.min_val[fid]);
ASSERT_GE(inst[j].fvalue, gmat.cut.MinValues()[fid]);
}
}
}
@@ -106,11 +108,12 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<std::vector<uint8_t>> hist_is_init;
std::vector<ExpandEntry> nodes = {ExpandEntry(nid, -1, -1, tree.GetDepth(0), 0.0, 0)};
BuildHistsBatch(nodes, const_cast<RegTree*>(&tree), gmat, gpair, &hist_buffers, &hist_is_init);
RealImpl::InitNewNode(nid, gmat, gpair, fmat, const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
RealImpl::InitNewNode(nid, gmat, gpair, fmat,
const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
EvaluateSplitsBatch(nodes, gmat, fmat, hist_is_init, hist_buffers);
// Check if number of histogram bins is correct
ASSERT_EQ(hist_[nid].size(), gmat.cut.row_ptr.back());
ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back());
std::vector<GradientPairPrecise> histogram_expected(hist_[nid].size());
// Compute the correct histogram (histogram_expected)
@@ -126,7 +129,7 @@ class QuantileHistMock : public QuantileHistMaker {
}
// Now validate the computed histogram returned by BuildHist
for (size_t i = 0; i < hist_[nid].size(); ++i) {
for (int64_t i = 0; i < hist_[nid].size(); ++i) {
GradientPairPrecise sol = histogram_expected[i];
ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
@@ -140,7 +143,7 @@ class QuantileHistMock : public QuantileHistMaker {
{0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
size_t constexpr kMaxBins = 4;
auto dmat = CreateDMatrix(kNRows, kNCols, 0, 3);
// dense, no missing values
// dense, no missing values
common::GHistIndexMatrix gmat;
gmat.Init((*dmat).get(), kMaxBins);
@@ -152,7 +155,8 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<std::vector<float*>> hist_buffers;
std::vector<std::vector<uint8_t>> hist_is_init;
BuildHistsBatch(nodes, const_cast<RegTree*>(&tree), gmat, row_gpairs, &hist_buffers, &hist_is_init);
RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat),
const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);
/* Compute correct split (best_split) using the computed histogram */
@@ -178,8 +182,8 @@ class QuantileHistMock : public QuantileHistMaker {
size_t best_split_feature = std::numeric_limits<size_t>::max();
// Enumerate all features
for (size_t fid = 0; fid < num_feature; ++fid) {
const size_t bin_id_min = gmat.cut.row_ptr[fid];
const size_t bin_id_max = gmat.cut.row_ptr[fid + 1];
const size_t bin_id_min = gmat.cut.Ptrs()[fid];
const size_t bin_id_max = gmat.cut.Ptrs()[fid + 1];
// Enumerate all bin ID in [bin_id_min, bin_id_max), i.e. every possible
// choice of thresholds for feature fid
for (size_t split_thresh = bin_id_min;
@@ -217,7 +221,7 @@ class QuantileHistMock : public QuantileHistMaker {
EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);
ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.cut[best_split_threshold]);
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
delete dmat;
}