diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 4679ef543..c2a69a204 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -42,6 +42,7 @@ DMatrix* SimpleDMatrix::Slice(common::Span ridxs) { out->Info() = this->Info().Slice(ridxs); out->Info().num_nonzero_ = h_offset.back(); } + out->ctx_ = this->ctx_; return out; } diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 266115731..198663872 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -1,13 +1,19 @@ -// Copyright by Contributors +/** + * Copyright 2016-2023 by XGBoost Contributors + */ #include -#include +#include // std::array +#include // std::numeric_limits +#include // std::unique_ptr -#include "../../../src/data/adapter.h" -#include "../../../src/data/simple_dmatrix.h" -#include "../filesystem.h" // dmlc::TemporaryDirectory -#include "../helpers.h" +#include "../../../src/data/adapter.h" // ArrayAdapter +#include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix +#include "../filesystem.h" // dmlc::TemporaryDirectory +#include "../helpers.h" // RandomDataGenerator,CreateSimpleTestData #include "xgboost/base.h" +#include "xgboost/host_device_vector.h" // HostDeviceVector +#include "xgboost/string_view.h" // StringView using namespace xgboost; // NOLINT @@ -298,6 +304,17 @@ TEST(SimpleDMatrix, Slice) { ASSERT_EQ(out->Info().num_col_, out->Info().num_col_); ASSERT_EQ(out->Info().num_row_, ridxs.size()); ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense + + { + HostDeviceVector data; + auto arr_str = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&data); + auto adapter = data::ArrayAdapter{StringView{arr_str}}; + auto n_threads = 2; + std::unique_ptr p_fmat{ + DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), n_threads, "")}; + std::unique_ptr slice{p_fmat->Slice(ridxs)}; + ASSERT_LE(slice->Ctx()->Threads(), n_threads); + } } TEST(SimpleDMatrix, SaveLoadBinary) {