[backport] Fix threads in DMatrix slice. (#8667) (#8679)

This commit is contained in:
Jiaming Yuan 2023-01-14 18:46:04 +08:00 committed by GitHub
parent 10bb0a74ef
commit e5bef4ffce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 6 deletions

View File

@ -42,6 +42,7 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
out->Info() = this->Info().Slice(ridxs);
out->Info().num_nonzero_ = h_offset.back();
}
out->ctx_ = this->ctx_;
return out;
}

View File

@ -1,13 +1,19 @@
// Copyright by Contributors
/**
* Copyright 2016-2023 by XGBoost Contributors
*/
#include <xgboost/data.h>
#include <array>
#include <array> // std::array
#include <limits> // std::numeric_limits
#include <memory> // std::unique_ptr
#include "../../../src/data/adapter.h"
#include "../../../src/data/simple_dmatrix.h"
#include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h"
#include "../../../src/data/adapter.h" // ArrayAdapter
#include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix
#include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h" // RandomDataGenerator,CreateSimpleTestData
#include "xgboost/base.h"
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/string_view.h" // StringView
using namespace xgboost; // NOLINT
@ -298,6 +304,17 @@ TEST(SimpleDMatrix, Slice) {
ASSERT_EQ(out->Info().num_col_, out->Info().num_col_);
ASSERT_EQ(out->Info().num_row_, ridxs.size());
ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense
{
HostDeviceVector<float> data;
auto arr_str = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&data);
auto adapter = data::ArrayAdapter{StringView{arr_str}};
auto n_threads = 2;
std::unique_ptr<DMatrix> p_fmat{
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), n_threads, "")};
std::unique_ptr<DMatrix> slice{p_fmat->Slice(ridxs)};
ASSERT_LE(slice->Ctx()->Threads(), n_threads);
}
}
TEST(SimpleDMatrix, SaveLoadBinary) {