Use context in SetInfo. (#7687)
* Use the name `Context`. * Pass a context object into `SetInfo`. * Add context to proxy matrix. * Add context to iterative DMatrix. This is to remove the use of the default number of threads during `SetInfo` as a follow-up on removing the global omp variable while preparing for CUDA stream semantic. Currently, XGBoost uses the legacy CUDA stream, we will gradually remove them in the future in favor of non-blocking streams.
This commit is contained in:
@@ -149,8 +149,7 @@ TEST(CutsBuilder, SearchGroupInd) {
|
||||
group[2] = 7;
|
||||
group[3] = 5;
|
||||
|
||||
p_mat->Info().SetInfo(
|
||||
"group", group.data(), DataType::kUInt32, kNumGroups);
|
||||
p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups);
|
||||
|
||||
HistogramCuts hmat;
|
||||
|
||||
@@ -350,6 +349,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||
|
||||
MetaInfo info;
|
||||
Context ctx;
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
if (with_group) {
|
||||
h_weights.resize(kGroups);
|
||||
@@ -363,7 +363,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
for (size_t i = 0; i < kGroups; ++i) {
|
||||
groups[i] = kRows / kGroups;
|
||||
}
|
||||
info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
info.num_row_ = kRows;
|
||||
@@ -371,10 +371,10 @@ void TestSketchFromWeights(bool with_group) {
|
||||
|
||||
// Assign weights.
|
||||
if (with_group) {
|
||||
m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
m->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
|
||||
m->SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
|
||||
m->Info().num_col_ = kCols;
|
||||
m->Info().num_row_ = kRows;
|
||||
ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
|
||||
|
||||
@@ -520,7 +520,7 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) {
|
||||
for (size_t i = 0; i < kGroups; ++i) {
|
||||
groups[i] = kRows / kGroups;
|
||||
}
|
||||
m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0);
|
||||
|
||||
h_weights.clear();
|
||||
@@ -550,6 +550,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(
|
||||
&storage);
|
||||
MetaInfo info;
|
||||
Context ctx;
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
if (with_group) {
|
||||
h_weights.resize(kGroups);
|
||||
@@ -563,7 +564,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
for (size_t i = 0; i < kGroups; ++i) {
|
||||
groups[i] = kRows / kGroups;
|
||||
}
|
||||
info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
info.weights_.SetDevice(0);
|
||||
@@ -582,10 +583,10 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
|
||||
auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
|
||||
if (with_group) {
|
||||
dmat->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
dmat->Info().SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
dmat->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
|
||||
dmat->Info().SetInfo(ctx, "weight", h_weights.data(), DataType::kFloat32, h_weights.size());
|
||||
dmat->Info().num_col_ = kCols;
|
||||
dmat->Info().num_row_ = kRows;
|
||||
ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
|
||||
|
||||
Reference in New Issue
Block a user