Convert labels into tensor. (#7456)
* Add a new ctor to tensor for `initilizer_list`. * Change labels from host device vector to tensor. * Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
@@ -204,8 +204,8 @@ class SerializationTest : public ::testing::Test {
|
||||
void SetUp() override {
|
||||
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix();
|
||||
|
||||
p_dmat_->Info().labels_.Resize(kRows);
|
||||
auto &h_labels = p_dmat_->Info().labels_.HostVector();
|
||||
p_dmat_->Info().labels.Reshape(kRows);
|
||||
auto& h_labels = p_dmat_->Info().labels.Data()->HostVector();
|
||||
|
||||
xgboost::SimpleLCG gen(0);
|
||||
SimpleRealUniformDistribution<float> dis(0.0f, 1.0f);
|
||||
@@ -219,6 +219,9 @@ class SerializationTest : public ::testing::Test {
|
||||
}
|
||||
};
|
||||
|
||||
size_t constexpr SerializationTest::kRows;
|
||||
size_t constexpr SerializationTest::kCols;
|
||||
|
||||
TEST_F(SerializationTest, Exact) {
|
||||
TestLearnerSerialization({{"booster", "gbtree"},
|
||||
{"seed", "0"},
|
||||
@@ -389,8 +392,8 @@ class LogitSerializationTest : public SerializationTest {
|
||||
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix();
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{p_dmat_};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
auto &h_labels = p_dmat->Info().labels_.HostVector();
|
||||
p_dmat->Info().labels.Reshape(kRows);
|
||||
auto& h_labels = p_dmat->Info().labels.Data()->HostVector();
|
||||
|
||||
std::bernoulli_distribution flip(0.5);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
@@ -513,8 +516,8 @@ class MultiClassesSerializationTest : public SerializationTest {
|
||||
p_dmat_ = RandomDataGenerator(kRows, kCols, .5f).GenerateDMatrix();
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{p_dmat_};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
auto &h_labels = p_dmat->Info().labels_.HostVector();
|
||||
p_dmat->Info().labels.Reshape(kRows);
|
||||
auto &h_labels = p_dmat->Info().labels.Data()->HostVector();
|
||||
|
||||
std::uniform_int_distribution<size_t> categorical(0, kClasses - 1);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
|
||||
Reference in New Issue
Block a user