From 07cf3d3e53f30cd7a62efecbe935d35689b9876e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 14 Jan 2023 07:16:57 +0800 Subject: [PATCH] Fix threads in DMatrix slice. (#8667) --- src/data/simple_dmatrix.cc | 1 + tests/cpp/data/test_simple_dmatrix.cc | 29 +++++++++++++++++++++------ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 808ecd8b3..28868da7d 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -42,6 +42,7 @@ DMatrix* SimpleDMatrix::Slice(common::Span ridxs) { out->Info() = this->Info().Slice(ridxs); out->Info().num_nonzero_ = h_offset.back(); } + out->ctx_ = this->ctx_; return out; } diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 8fc8ff017..9d54751a7 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -1,13 +1,19 @@ -// Copyright by Contributors +/** + * Copyright 2016-2023 by XGBoost Contributors + */ #include -#include +#include // std::array +#include // std::numeric_limits +#include // std::unique_ptr -#include "../../../src/data/adapter.h" -#include "../../../src/data/simple_dmatrix.h" -#include "../filesystem.h" // dmlc::TemporaryDirectory -#include "../helpers.h" +#include "../../../src/data/adapter.h" // ArrayAdapter +#include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix +#include "../filesystem.h" // dmlc::TemporaryDirectory +#include "../helpers.h" // RandomDataGenerator,CreateSimpleTestData #include "xgboost/base.h" +#include "xgboost/host_device_vector.h" // HostDeviceVector +#include "xgboost/string_view.h" // StringView using namespace xgboost; // NOLINT @@ -299,6 +305,17 @@ TEST(SimpleDMatrix, Slice) { ASSERT_EQ(out->Info().num_col_, out->Info().num_col_); ASSERT_EQ(out->Info().num_row_, ridxs.size()); ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense + + { + HostDeviceVector data; + auto arr_str = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&data); + auto adapter = data::ArrayAdapter{StringView{arr_str}}; + auto n_threads = 2; + std::unique_ptr p_fmat{ + DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), n_threads, "")}; + std::unique_ptr slice{p_fmat->Slice(ridxs)}; + ASSERT_LE(slice->Ctx()->Threads(), n_threads); + } } TEST(SimpleDMatrix, SliceCol) {