Fix threads in DMatrix slice. (#8667)
This commit is contained in:
parent
e27cda7626
commit
07cf3d3e53
@ -42,6 +42,7 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
|
|||||||
out->Info() = this->Info().Slice(ridxs);
|
out->Info() = this->Info().Slice(ridxs);
|
||||||
out->Info().num_nonzero_ = h_offset.back();
|
out->Info().num_nonzero_ = h_offset.back();
|
||||||
}
|
}
|
||||||
|
out->ctx_ = this->ctx_;
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,13 +1,19 @@
|
|||||||
// Copyright by Contributors
|
/**
|
||||||
|
* Copyright 2016-2023 by XGBoost Contributors
|
||||||
|
*/
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
#include <array>
|
#include <array> // std::array
|
||||||
|
#include <limits> // std::numeric_limits
|
||||||
|
#include <memory> // std::unique_ptr
|
||||||
|
|
||||||
#include "../../../src/data/adapter.h"
|
#include "../../../src/data/adapter.h" // ArrayAdapter
|
||||||
#include "../../../src/data/simple_dmatrix.h"
|
#include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix
|
||||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||||
#include "../helpers.h"
|
#include "../helpers.h" // RandomDataGenerator,CreateSimpleTestData
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
|
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||||
|
#include "xgboost/string_view.h" // StringView
|
||||||
|
|
||||||
using namespace xgboost; // NOLINT
|
using namespace xgboost; // NOLINT
|
||||||
|
|
||||||
@ -299,6 +305,17 @@ TEST(SimpleDMatrix, Slice) {
|
|||||||
ASSERT_EQ(out->Info().num_col_, out->Info().num_col_);
|
ASSERT_EQ(out->Info().num_col_, out->Info().num_col_);
|
||||||
ASSERT_EQ(out->Info().num_row_, ridxs.size());
|
ASSERT_EQ(out->Info().num_row_, ridxs.size());
|
||||||
ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense
|
ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense
|
||||||
|
|
||||||
|
{
|
||||||
|
HostDeviceVector<float> data;
|
||||||
|
auto arr_str = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&data);
|
||||||
|
auto adapter = data::ArrayAdapter{StringView{arr_str}};
|
||||||
|
auto n_threads = 2;
|
||||||
|
std::unique_ptr<DMatrix> p_fmat{
|
||||||
|
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), n_threads, "")};
|
||||||
|
std::unique_ptr<DMatrix> slice{p_fmat->Slice(ridxs)};
|
||||||
|
ASSERT_LE(slice->Ctx()->Threads(), n_threads);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SimpleDMatrix, SliceCol) {
|
TEST(SimpleDMatrix, SliceCol) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user