Use context in SetInfo. (#7687)
* Use the name `Context`. * Pass a context object into `SetInfo`. * Add context to proxy matrix. * Add context to iterative DMatrix. This is to remove the use of the default number of threads during `SetInfo` as a follow-up on removing the global omp variable while preparing for CUDA stream semantic. Currently, XGBoost uses the legacy CUDA stream, we will gradually remove them in the future in favor of non-blocking streams.
This commit is contained in:
@@ -149,8 +149,7 @@ TEST(CutsBuilder, SearchGroupInd) {
|
||||
group[2] = 7;
|
||||
group[3] = 5;
|
||||
|
||||
p_mat->Info().SetInfo(
|
||||
"group", group.data(), DataType::kUInt32, kNumGroups);
|
||||
p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups);
|
||||
|
||||
HistogramCuts hmat;
|
||||
|
||||
@@ -350,6 +349,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));
|
||||
|
||||
MetaInfo info;
|
||||
Context ctx;
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
if (with_group) {
|
||||
h_weights.resize(kGroups);
|
||||
@@ -363,7 +363,7 @@ void TestSketchFromWeights(bool with_group) {
|
||||
for (size_t i = 0; i < kGroups; ++i) {
|
||||
groups[i] = kRows / kGroups;
|
||||
}
|
||||
info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
info.num_row_ = kRows;
|
||||
@@ -371,10 +371,10 @@ void TestSketchFromWeights(bool with_group) {
|
||||
|
||||
// Assign weights.
|
||||
if (with_group) {
|
||||
m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
m->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
|
||||
m->SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
|
||||
m->Info().num_col_ = kCols;
|
||||
m->Info().num_row_ = kRows;
|
||||
ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
|
||||
|
||||
@@ -520,7 +520,7 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) {
|
||||
for (size_t i = 0; i < kGroups; ++i) {
|
||||
groups[i] = kRows / kGroups;
|
||||
}
|
||||
m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0);
|
||||
|
||||
h_weights.clear();
|
||||
@@ -550,6 +550,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(
|
||||
&storage);
|
||||
MetaInfo info;
|
||||
Context ctx;
|
||||
auto& h_weights = info.weights_.HostVector();
|
||||
if (with_group) {
|
||||
h_weights.resize(kGroups);
|
||||
@@ -563,7 +564,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
for (size_t i = 0; i < kGroups; ++i) {
|
||||
groups[i] = kRows / kGroups;
|
||||
}
|
||||
info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
info.weights_.SetDevice(0);
|
||||
@@ -582,10 +583,10 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
|
||||
auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
|
||||
if (with_group) {
|
||||
dmat->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
|
||||
dmat->Info().SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
|
||||
}
|
||||
|
||||
dmat->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
|
||||
dmat->Info().SetInfo(ctx, "weight", h_weights.data(), DataType::kFloat32, h_weights.size());
|
||||
dmat->Info().num_col_ = kCols;
|
||||
dmat->Info().num_row_ = kRows;
|
||||
ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);
|
||||
|
||||
@@ -12,28 +12,29 @@
|
||||
#include "xgboost/base.h"
|
||||
|
||||
TEST(MetaInfo, GetSet) {
|
||||
xgboost::Context ctx;
|
||||
xgboost::MetaInfo info;
|
||||
|
||||
double double2[2] = {1.0, 2.0};
|
||||
|
||||
EXPECT_EQ(info.labels.Size(), 0);
|
||||
info.SetInfo("label", double2, xgboost::DataType::kFloat32, 2);
|
||||
info.SetInfo(ctx, "label", double2, xgboost::DataType::kFloat32, 2);
|
||||
EXPECT_EQ(info.labels.Size(), 2);
|
||||
|
||||
float float2[2] = {1.0f, 2.0f};
|
||||
EXPECT_EQ(info.GetWeight(1), 1.0f)
|
||||
<< "When no weights are given, was expecting default value 1";
|
||||
info.SetInfo("weight", float2, xgboost::DataType::kFloat32, 2);
|
||||
info.SetInfo(ctx, "weight", float2, xgboost::DataType::kFloat32, 2);
|
||||
EXPECT_EQ(info.GetWeight(1), 2.0f);
|
||||
|
||||
uint32_t uint32_t2[2] = {1U, 2U};
|
||||
EXPECT_EQ(info.base_margin_.Size(), 0);
|
||||
info.SetInfo("base_margin", uint32_t2, xgboost::DataType::kUInt32, 2);
|
||||
info.SetInfo(ctx, "base_margin", uint32_t2, xgboost::DataType::kUInt32, 2);
|
||||
EXPECT_EQ(info.base_margin_.Size(), 2);
|
||||
|
||||
uint64_t uint64_t2[2] = {1U, 2U};
|
||||
EXPECT_EQ(info.group_ptr_.size(), 0);
|
||||
info.SetInfo("group", uint64_t2, xgboost::DataType::kUInt64, 2);
|
||||
info.SetInfo(ctx, "group", uint64_t2, xgboost::DataType::kUInt64, 2);
|
||||
ASSERT_EQ(info.group_ptr_.size(), 3);
|
||||
EXPECT_EQ(info.group_ptr_[2], 3);
|
||||
|
||||
@@ -73,6 +74,8 @@ TEST(MetaInfo, GetSetFeature) {
|
||||
|
||||
TEST(MetaInfo, SaveLoadBinary) {
|
||||
xgboost::MetaInfo info;
|
||||
xgboost::Context ctx;
|
||||
|
||||
uint64_t constexpr kRows { 64 }, kCols { 32 };
|
||||
auto generator = []() {
|
||||
static float f = 0;
|
||||
@@ -80,9 +83,9 @@ TEST(MetaInfo, SaveLoadBinary) {
|
||||
};
|
||||
std::vector<float> values (kRows);
|
||||
std::generate(values.begin(), values.end(), generator);
|
||||
info.SetInfo("label", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.SetInfo("weight", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.SetInfo("base_margin", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.SetInfo(ctx, "label", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.SetInfo(ctx, "weight", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
info.SetInfo(ctx, "base_margin", values.data(), xgboost::DataType::kFloat32, kRows);
|
||||
|
||||
info.num_row_ = kRows;
|
||||
info.num_col_ = kCols;
|
||||
@@ -210,13 +213,14 @@ TEST(MetaInfo, LoadQid) {
|
||||
|
||||
TEST(MetaInfo, CPUQid) {
|
||||
xgboost::MetaInfo info;
|
||||
xgboost::Context ctx;
|
||||
info.num_row_ = 100;
|
||||
std::vector<uint32_t> qid(info.num_row_, 0);
|
||||
for (size_t i = 0; i < qid.size(); ++i) {
|
||||
qid[i] = i;
|
||||
}
|
||||
|
||||
info.SetInfo("qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_);
|
||||
info.SetInfo(ctx, "qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_);
|
||||
ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
|
||||
ASSERT_EQ(info.group_ptr_.front(), 0);
|
||||
ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
|
||||
@@ -232,12 +236,15 @@ TEST(MetaInfo, Validate) {
|
||||
info.num_nonzero_ = 12;
|
||||
info.num_col_ = 3;
|
||||
std::vector<xgboost::bst_group_t> groups (11);
|
||||
info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, 11);
|
||||
xgboost::Context ctx;
|
||||
info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, 11);
|
||||
EXPECT_THROW(info.Validate(0), dmlc::Error);
|
||||
|
||||
std::vector<float> labels(info.num_row_ + 1);
|
||||
EXPECT_THROW(
|
||||
{ info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); },
|
||||
{
|
||||
info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1);
|
||||
},
|
||||
dmlc::Error);
|
||||
|
||||
// Make overflow data, which can happen when users pass group structure as int
|
||||
@@ -247,14 +254,13 @@ TEST(MetaInfo, Validate) {
|
||||
groups.push_back(1562500);
|
||||
}
|
||||
groups.push_back(static_cast<xgboost::bst_group_t>(-1));
|
||||
EXPECT_THROW(info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32,
|
||||
groups.size()),
|
||||
EXPECT_THROW(info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()),
|
||||
dmlc::Error);
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
info.group_ptr_.clear();
|
||||
labels.resize(info.num_row_);
|
||||
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
|
||||
info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
|
||||
info.labels.SetDevice(0);
|
||||
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
||||
|
||||
@@ -263,12 +269,13 @@ TEST(MetaInfo, Validate) {
|
||||
d_groups.DevicePointer(); // pull to device
|
||||
std::string arr_interface_str{ArrayInterfaceStr(
|
||||
xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0))};
|
||||
EXPECT_THROW(info.SetInfo("group", xgboost::StringView{arr_interface_str}), dmlc::Error);
|
||||
EXPECT_THROW(info.SetInfo(ctx, "group", xgboost::StringView{arr_interface_str}), dmlc::Error);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
TEST(MetaInfo, HostExtend) {
|
||||
xgboost::MetaInfo lhs, rhs;
|
||||
xgboost::Context ctx;
|
||||
size_t const kRows = 100;
|
||||
lhs.labels.Reshape(kRows);
|
||||
lhs.num_row_ = kRows;
|
||||
@@ -282,8 +289,8 @@ TEST(MetaInfo, HostExtend) {
|
||||
for (size_t g = 0; g < kRows / per_group; ++g) {
|
||||
groups.emplace_back(per_group);
|
||||
}
|
||||
lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
|
||||
rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
|
||||
lhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size());
|
||||
rhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size());
|
||||
|
||||
lhs.Extend(rhs, true, true);
|
||||
ASSERT_EQ(lhs.num_row_, kRows * 2);
|
||||
@@ -300,5 +307,5 @@ TEST(MetaInfo, HostExtend) {
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(GenericParameter::kCpuId); }
|
||||
TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(Context::kCpuId); }
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -25,14 +25,13 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons
|
||||
|
||||
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
|
||||
column["shape"] = Array(j_shape);
|
||||
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(sizeof(T))))});
|
||||
column["strides"] = Array(std::vector<Json>{Json(Integer{static_cast<Integer::Int>(sizeof(T))})});
|
||||
column["version"] = 3;
|
||||
column["typestr"] = String(typestr);
|
||||
|
||||
auto p_d_data = d_data.data().get();
|
||||
std::vector<Json> j_data {
|
||||
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
|
||||
Json(Boolean(false))};
|
||||
std::vector<Json> j_data{Json(Integer{reinterpret_cast<Integer::Int>(p_d_data)}),
|
||||
Json(Boolean(false))};
|
||||
column["data"] = j_data;
|
||||
column["stream"] = nullptr;
|
||||
Json array(std::vector<Json>{column});
|
||||
@@ -45,12 +44,13 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons
|
||||
|
||||
TEST(MetaInfo, FromInterface) {
|
||||
cudaSetDevice(0);
|
||||
Context ctx;
|
||||
thrust::device_vector<float> d_data;
|
||||
|
||||
std::string str = PrepareData<float>("<f4", &d_data);
|
||||
|
||||
MetaInfo info;
|
||||
info.SetInfo("label", str.c_str());
|
||||
info.SetInfo(ctx, "label", str.c_str());
|
||||
|
||||
auto const& h_label = info.labels.HostView();
|
||||
ASSERT_EQ(h_label.Size(), d_data.size());
|
||||
@@ -58,13 +58,13 @@ TEST(MetaInfo, FromInterface) {
|
||||
ASSERT_EQ(h_label(i), d_data[i]);
|
||||
}
|
||||
|
||||
info.SetInfo("weight", str.c_str());
|
||||
info.SetInfo(ctx, "weight", str.c_str());
|
||||
auto const& h_weight = info.weights_.HostVector();
|
||||
for (size_t i = 0; i < d_data.size(); ++i) {
|
||||
ASSERT_EQ(h_weight[i], d_data[i]);
|
||||
}
|
||||
|
||||
info.SetInfo("base_margin", str.c_str());
|
||||
info.SetInfo(ctx, "base_margin", str.c_str());
|
||||
auto const h_base_margin = info.base_margin_.View(GenericParameter::kCpuId);
|
||||
ASSERT_EQ(h_base_margin.Size(), d_data.size());
|
||||
for (size_t i = 0; i < d_data.size(); ++i) {
|
||||
@@ -77,7 +77,7 @@ TEST(MetaInfo, FromInterface) {
|
||||
d_group_data[1] = 3;
|
||||
d_group_data[2] = 2;
|
||||
d_group_data[3] = 1;
|
||||
info.SetInfo("group", group_str.c_str());
|
||||
info.SetInfo(ctx, "group", group_str.c_str());
|
||||
std::vector<bst_group_t> expected_group_ptr = {0, 4, 7, 9, 10};
|
||||
EXPECT_EQ(info.group_ptr_, expected_group_ptr);
|
||||
}
|
||||
@@ -89,10 +89,11 @@ TEST(MetaInfo, GPUStridedData) {
|
||||
TEST(MetaInfo, Group) {
|
||||
cudaSetDevice(0);
|
||||
MetaInfo info;
|
||||
Context ctx;
|
||||
|
||||
thrust::device_vector<uint32_t> d_uint;
|
||||
std::string uint_str = PrepareData<uint32_t>("<u4", &d_uint);
|
||||
info.SetInfo("group", uint_str.c_str());
|
||||
info.SetInfo(ctx, "group", uint_str.c_str());
|
||||
auto& h_group = info.group_ptr_;
|
||||
ASSERT_EQ(h_group.size(), d_uint.size() + 1);
|
||||
for (size_t i = 1; i < h_group.size(); ++i) {
|
||||
@@ -102,7 +103,7 @@ TEST(MetaInfo, Group) {
|
||||
thrust::device_vector<int64_t> d_int64;
|
||||
std::string int_str = PrepareData<int64_t>("<i8", &d_int64);
|
||||
info = MetaInfo();
|
||||
info.SetInfo("group", int_str.c_str());
|
||||
info.SetInfo(ctx, "group", int_str.c_str());
|
||||
h_group = info.group_ptr_;
|
||||
ASSERT_EQ(h_group.size(), d_uint.size() + 1);
|
||||
for (size_t i = 1; i < h_group.size(); ++i) {
|
||||
@@ -113,11 +114,12 @@ TEST(MetaInfo, Group) {
|
||||
thrust::device_vector<float> d_float;
|
||||
std::string float_str = PrepareData<float>("<f4", &d_float);
|
||||
info = MetaInfo();
|
||||
EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
|
||||
EXPECT_ANY_THROW(info.SetInfo(ctx, "group", float_str.c_str()));
|
||||
}
|
||||
|
||||
TEST(MetaInfo, GPUQid) {
|
||||
xgboost::MetaInfo info;
|
||||
Context ctx;
|
||||
info.num_row_ = 100;
|
||||
thrust::device_vector<uint32_t> qid(info.num_row_, 0);
|
||||
for (size_t i = 0; i < qid.size(); ++i) {
|
||||
@@ -127,7 +129,7 @@ TEST(MetaInfo, GPUQid) {
|
||||
Json array{std::vector<Json>{column}};
|
||||
std::string array_str;
|
||||
Json::Dump(array, &array_str);
|
||||
info.SetInfo("qid", array_str.c_str());
|
||||
info.SetInfo(ctx, "qid", array_str.c_str());
|
||||
ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
|
||||
ASSERT_EQ(info.group_ptr_.front(), 0);
|
||||
ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
|
||||
@@ -142,11 +144,12 @@ TEST(MetaInfo, DeviceExtend) {
|
||||
dh::safe_cuda(cudaSetDevice(0));
|
||||
size_t const kRows = 100;
|
||||
MetaInfo lhs, rhs;
|
||||
Context ctx;
|
||||
|
||||
thrust::device_vector<float> d_data;
|
||||
std::string str = PrepareData<float>("<f4", &d_data, kRows);
|
||||
lhs.SetInfo("label", str.c_str());
|
||||
rhs.SetInfo("label", str.c_str());
|
||||
lhs.SetInfo(ctx, "label", str.c_str());
|
||||
rhs.SetInfo(ctx, "label", str.c_str());
|
||||
ASSERT_FALSE(rhs.labels.Data()->HostCanRead());
|
||||
lhs.num_row_ = kRows;
|
||||
rhs.num_row_ = kRows;
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
namespace xgboost {
|
||||
inline void TestMetaInfoStridedData(int32_t device) {
|
||||
MetaInfo info;
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"gpu_id", std::to_string(device)}});
|
||||
{
|
||||
// labels
|
||||
linalg::Tensor<float, 3> labels;
|
||||
@@ -25,7 +27,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
|
||||
auto t_labels = labels.View(device).Slice(linalg::All(), 0, linalg::All());
|
||||
ASSERT_EQ(t_labels.Shape().size(), 2);
|
||||
|
||||
info.SetInfo("label", StringView{ArrayInterfaceStr(t_labels)});
|
||||
info.SetInfo(ctx, "label", StringView{ArrayInterfaceStr(t_labels)});
|
||||
auto const& h_result = info.labels.View(-1);
|
||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||
auto in_labels = labels.View(-1);
|
||||
@@ -46,7 +48,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
|
||||
std::iota(h_qid.begin(), h_qid.end(), 0);
|
||||
auto s = qid.View(device).Slice(linalg::All(), 0);
|
||||
auto str = ArrayInterfaceStr(s);
|
||||
info.SetInfo("qid", StringView{str});
|
||||
info.SetInfo(ctx, "qid", StringView{str});
|
||||
auto const& h_result = info.group_ptr_;
|
||||
ASSERT_EQ(h_result.size(), s.Size() + 1);
|
||||
}
|
||||
@@ -59,7 +61,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
|
||||
auto t_margin = base_margin.View(device).Slice(linalg::All(), 0, linalg::All());
|
||||
ASSERT_EQ(t_margin.Shape().size(), 2);
|
||||
|
||||
info.SetInfo("base_margin", StringView{ArrayInterfaceStr(t_margin)});
|
||||
info.SetInfo(ctx, "base_margin", StringView{ArrayInterfaceStr(t_margin)});
|
||||
auto const& h_result = info.base_margin_.View(-1);
|
||||
ASSERT_EQ(h_result.Shape().size(), 2);
|
||||
auto in_margin = base_margin.View(-1);
|
||||
|
||||
@@ -257,7 +257,7 @@ TEST(Dart, Prediction) {
|
||||
for (size_t i = 0; i < kRows; ++i) {
|
||||
labels[i] = i % 2;
|
||||
}
|
||||
p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kRows);
|
||||
p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows);
|
||||
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
|
||||
learner->SetParam("booster", "dart");
|
||||
|
||||
@@ -74,11 +74,9 @@ TEST(Learner, CheckGroup) {
|
||||
labels[i] = i % 2;
|
||||
}
|
||||
|
||||
p_mat->Info().SetInfo(
|
||||
"weight", static_cast<void*>(weight.data()), DataType::kFloat32, kNumGroups);
|
||||
p_mat->Info().SetInfo(
|
||||
"group", group.data(), DataType::kUInt32, kNumGroups);
|
||||
p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kNumRows);
|
||||
p_mat->SetInfo("weight", static_cast<void *>(weight.data()), DataType::kFloat32, kNumGroups);
|
||||
p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups);
|
||||
p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kNumRows);
|
||||
|
||||
std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {p_mat};
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
||||
@@ -88,7 +86,7 @@ TEST(Learner, CheckGroup) {
|
||||
group.resize(kNumGroups+1);
|
||||
group[3] = 4;
|
||||
group[4] = 1;
|
||||
p_mat->Info().SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
|
||||
p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
|
||||
EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat));
|
||||
}
|
||||
|
||||
@@ -105,7 +103,7 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT
|
||||
for (size_t i = 0; i < num_row; ++i) {
|
||||
labels[i] = i % 2;
|
||||
}
|
||||
dmat->Info().SetInfo("label", labels.data(), DataType::kFloat32, num_row);
|
||||
dmat->SetInfo("label", labels.data(), DataType::kFloat32, num_row);
|
||||
std::vector<std::shared_ptr<DMatrix>> mat{dmat};
|
||||
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
|
||||
learner->SetParams(Args{{"objective", "binary:logistic"}});
|
||||
|
||||
Reference in New Issue
Block a user