Unify CPU hist sketching (#5880)
This commit is contained in:
@@ -24,10 +24,8 @@ namespace common {
|
||||
|
||||
template <typename AdapterT>
|
||||
HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
|
||||
HistogramCuts cuts;
|
||||
DenseCuts builder(&cuts);
|
||||
data::SimpleDMatrix dmat(adapter, missing, 1);
|
||||
builder.Build(&dmat, num_bins);
|
||||
HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins);
|
||||
return cuts;
|
||||
}
|
||||
|
||||
@@ -39,9 +37,7 @@ TEST(HistUtil, DeviceSketch) {
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
|
||||
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
HistogramCuts host_cuts;
|
||||
DenseCuts builder(&host_cuts);
|
||||
builder.Build(dmat.get(), num_bins);
|
||||
HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins);
|
||||
|
||||
EXPECT_EQ(device_cuts.Values(), host_cuts.Values());
|
||||
EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs());
|
||||
@@ -460,7 +456,11 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
&storage);
|
||||
MetaInfo info;
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
h_weights.resize(kRows);
|
||||
if (with_group) {
|
||||
h_weights.resize(kGroups);
|
||||
} else {
|
||||
h_weights.resize(kRows);
|
||||
}
|
||||
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
|
||||
|
||||
std::vector<bst_group_t> groups(kGroups);
|
||||
|
||||
Reference in New Issue
Block a user