Use context in SetInfo. (#7687)

* Use the name `Context`.
* Pass a context object into `SetInfo`.
* Add context to proxy matrix.
* Add context to iterative DMatrix.

This removes the use of the default number of threads during `SetInfo`, as a follow-up to
removing the global omp variable, and prepares for the CUDA stream semantic. Currently XGBoost
uses the legacy CUDA stream; we will gradually replace it with non-blocking streams in the future.
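
As a rough sketch of the resulting call pattern (mirroring the updated tests in this diff; the
`nthread` parameter already exists on the context), the caller now supplies the thread count
through a `Context` instead of relying on a global OpenMP default:

#include "xgboost/base.h"                // xgboost::Args
#include "xgboost/data.h"                // xgboost::MetaInfo, xgboost::DataType
#include "xgboost/generic_parameters.h"  // xgboost::Context

// Sketch only: the thread count travels with the Context that is passed to SetInfo.
void SetLabelsExample() {
  xgboost::Context ctx;
  ctx.UpdateAllowUnknown(xgboost::Args{{"nthread", "4"}});

  xgboost::MetaInfo info;
  float labels[2] = {0.0f, 1.0f};
  info.SetInfo(ctx, "label", labels, xgboost::DataType::kFloat32, 2);
}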
Commit 64575591d8 (parent f5b20286e2)
Authored by Jiaming Yuan on 2022-03-24 22:16:26 +08:00; committed by GitHub
19 changed files with 142 additions and 142 deletions

View File

@@ -148,13 +148,13 @@ class MetaInfo {
    * \param dtype The type of the source data.
    * \param num Number of elements in the source array.
    */
-  void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
+  void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
   /*!
    * \brief Set information in the meta info with array interface.
    * \param key The key of the information.
    * \param interface_str String representation of json format array interface.
    */
-  void SetInfo(StringView key, StringView interface_str);
+  void SetInfo(Context const& ctx, StringView key, StringView interface_str);
   void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
                const void** out_dptr) const;
@@ -176,8 +176,8 @@ class MetaInfo {
   void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);

  private:
-  void SetInfoFromHost(StringView key, Json arr);
-  void SetInfoFromCUDA(StringView key, Json arr);
+  void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
+  void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
   /*! \brief argsort of labels */
   mutable std::vector<size_t> label_order_cache_;
@@ -478,12 +478,13 @@ class DMatrix {
   DMatrix() = default;
   /*! \brief meta information of the dataset */
   virtual MetaInfo& Info() = 0;
-  virtual void SetInfo(const char *key, const void *dptr, DataType dtype,
-                       size_t num) {
-    this->Info().SetInfo(key, dptr, dtype, num);
+  virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
+    auto const& ctx = *this->Ctx();
+    this->Info().SetInfo(ctx, key, dptr, dtype, num);
   }
   virtual void SetInfo(const char* key, std::string const& interface_str) {
-    this->Info().SetInfo(key, StringView{interface_str});
+    auto const& ctx = *this->Ctx();
+    this->Info().SetInfo(ctx, key, StringView{interface_str});
   }
   /*! \brief meta information of the dataset */
   virtual const MetaInfo& Info() const = 0;
@@ -494,7 +495,7 @@ class DMatrix {
   * \brief Get the context object of this DMatrix. The context is created during construction of
   * DMatrix with user specified `nthread` parameter.
   */
-  virtual GenericParameter const* Ctx() const = 0;
+  virtual Context const* Ctx() const = 0;
  /**
   * \brief Gets batches. Use range based for loop over BatchSet to access individual batches.

View File

@@ -75,6 +75,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
         .describe("Enable checking whether parameters are used or not.");
   }
 };
+using Context = GenericParameter;
 }  // namespace xgboost
 #endif  // XGBOOST_GENERIC_PARAMETERS_H_
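
Since `Context` is, for now, just an alias of `GenericParameter`, the rename is source
compatible. A minimal illustration (not part of the diff):

#include <type_traits>
#include "xgboost/generic_parameters.h"

// Both spellings name the same type, so existing code using GenericParameter keeps compiling.
static_assert(std::is_same<xgboost::Context, xgboost::GenericParameter>::value,
              "Context is an alias of GenericParameter");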

View File

@@ -485,35 +485,30 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname,
   API_END();
 }

-XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
-                                  const char* field,
-                                  const bst_float* info,
+XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const bst_float *info,
                                   xgboost::bst_ulong len) {
   API_BEGIN();
   CHECK_HANDLE();
-  static_cast<std::shared_ptr<DMatrix>*>(handle)
-      ->get()->Info().SetInfo(field, info, xgboost::DataType::kFloat32, len);
+  auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
+  p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
   API_END();
 }

-XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
-                                          char const* field,
-                                          char const* interface_c_str) {
+XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, char const *field,
+                                          char const *interface_c_str) {
   API_BEGIN();
   CHECK_HANDLE();
-  static_cast<std::shared_ptr<DMatrix>*>(handle)
-      ->get()->Info().SetInfo(field, interface_c_str);
+  auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
+  p_fmat->SetInfo(field, interface_c_str);
   API_END();
 }

-XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
-                                 const char* field,
-                                 const unsigned* info,
+XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *info,
                                  xgboost::bst_ulong len) {
   API_BEGIN();
   CHECK_HANDLE();
-  static_cast<std::shared_ptr<DMatrix>*>(handle)
-      ->get()->Info().SetInfo(field, info, xgboost::DataType::kUInt32, len);
+  auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
+  p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
   API_END();
 }
@@ -549,25 +544,22 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
   API_END();
 }

-XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
-                                  void const *data, xgboost::bst_ulong size,
-                                  int type) {
+XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
+                                  xgboost::bst_ulong size, int type) {
   API_BEGIN();
   CHECK_HANDLE();
-  auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
+  auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
   CHECK(type >= 1 && type <= 4);
-  info.SetInfo(field, data, static_cast<DataType>(type), size);
+  p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
   API_END();
 }

-XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
-                              const unsigned* group,
-                              xgboost::bst_ulong len) {
+XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
   API_BEGIN();
   CHECK_HANDLE();
   LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
-  static_cast<std::shared_ptr<DMatrix>*>(handle)
-      ->get()->Info().SetInfo("group", group, xgboost::DataType::kUInt32, len);
+  auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
+  p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
   API_END();
 }
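
The C API surface itself is unchanged; the setters now route through `DMatrix::SetInfo`, which
supplies the DMatrix's own context. A hedged usage sketch (error handling omitted;
`XGDMatrixCreateFromMat` is an existing creator that is not part of this diff):

#include "xgboost/c_api.h"

void SetLabelViaCAPI() {
  float data[4] = {0.f, 1.f, 2.f, 3.f};
  float labels[2] = {0.f, 1.f};
  DMatrixHandle dmat;
  XGDMatrixCreateFromMat(data, 2, 2, -1.f, &dmat);  // 2x2 dense matrix, -1 as missing value
  XGDMatrixSetFloatInfo(dmat, "label", labels, 2);  // forwarded to DMatrix::SetInfo internally
  XGDMatrixFree(dmat);
}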

View File

@@ -409,7 +409,7 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
 namespace {
 template <int32_t D, typename T>
-void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
+void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
   ArrayInterface<D> array{arr_interface};
   if (array.n == 0) {
     p_out->Reshape(array.shape);
@@ -428,16 +428,15 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
     return;
   }
   p_out->Reshape(array.shape);
-  auto t = p_out->View(GenericParameter::kCpuId);
+  auto t = p_out->View(Context::kCpuId);
   CHECK(t.CContiguous());
-  // FIXME(jiamingy): Remove the use of this default thread.
-  linalg::ElementWiseTransformHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) {
+  linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
     return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
   });
 }
 }  // namespace

-void MetaInfo::SetInfo(StringView key, StringView interface_str) {
+void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_str) {
   Json j_interface = Json::Load(interface_str);
   bool is_cuda{false};
   if (IsA<Array>(j_interface)) {
@@ -454,16 +453,16 @@ void MetaInfo::SetInfo(StringView key, StringView interface_str) {
   }

   if (is_cuda) {
-    this->SetInfoFromCUDA(key, j_interface);
+    this->SetInfoFromCUDA(ctx, key, j_interface);
   } else {
-    this->SetInfoFromHost(key, j_interface);
+    this->SetInfoFromHost(ctx, key, j_interface);
   }
 }

-void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
+void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
   // multi-dim float info
   if (key == "base_margin") {
-    CopyTensorInfoImpl(arr, &this->base_margin_);
+    CopyTensorInfoImpl(ctx, arr, &this->base_margin_);
     // FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of
     // input shape. This issue is CPU only since CUDA uses array interface from day 1.
     //
@@ -477,7 +476,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
     }
     return;
   } else if (key == "label") {
-    CopyTensorInfoImpl(arr, &this->labels);
+    CopyTensorInfoImpl(ctx, arr, &this->labels);
     if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) {
       CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels.";
       size_t n_targets = this->labels.Size() / this->num_row_;
@@ -491,7 +490,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
   // uint info
   if (key == "group") {
     linalg::Tensor<bst_group_t, 1> t;
-    CopyTensorInfoImpl(arr, &t);
+    CopyTensorInfoImpl(ctx, arr, &t);
     auto const& h_groups = t.Data()->HostVector();
     group_ptr_.clear();
     group_ptr_.resize(h_groups.size() + 1, 0);
@@ -501,7 +500,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
     return;
   } else if (key == "qid") {
     linalg::Tensor<bst_group_t, 1> t;
-    CopyTensorInfoImpl(arr, &t);
+    CopyTensorInfoImpl(ctx, arr, &t);
     bool non_dec = true;
     auto const& query_ids = t.Data()->HostVector();
     for (size_t i = 1; i < query_ids.size(); ++i) {
@@ -526,7 +525,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
   }
   // float info
   linalg::Tensor<float, 1> t;
-  CopyTensorInfoImpl<1>(arr, &t);
+  CopyTensorInfoImpl<1>(ctx, arr, &t);
   if (key == "weight") {
     this->weights_ = std::move(*t.Data());
     auto const& h_weights = this->weights_.ConstHostVector();
@@ -548,13 +547,15 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
   }
 }

-void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
+void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
+                       size_t num) {
   auto proc = [&](auto cast_d_ptr) {
     using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
-    auto t =
-        linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, GenericParameter::kCpuId);
+    auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
     CHECK(t.CContiguous());
-    Json interface { linalg::ArrayInterface(t) };
+    Json interface {
+      linalg::ArrayInterface(t)
+    };
     assert(ArrayInterface<1>{interface}.is_contiguous);
     return interface;
   };
@@ -562,22 +563,22 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
   switch (dtype) {
     case xgboost::DataType::kFloat32: {
       auto cast_ptr = reinterpret_cast<const float*>(dptr);
-      this->SetInfoFromHost(key, proc(cast_ptr));
+      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
       break;
     }
     case xgboost::DataType::kDouble: {
       auto cast_ptr = reinterpret_cast<const double*>(dptr);
-      this->SetInfoFromHost(key, proc(cast_ptr));
+      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
       break;
     }
     case xgboost::DataType::kUInt32: {
       auto cast_ptr = reinterpret_cast<const uint32_t*>(dptr);
-      this->SetInfoFromHost(key, proc(cast_ptr));
+      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
       break;
     }
     case xgboost::DataType::kUInt64: {
       auto cast_ptr = reinterpret_cast<const uint64_t*>(dptr);
-      this->SetInfoFromHost(key, proc(cast_ptr));
+      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
       break;
     }
     default:
@@ -724,9 +725,7 @@ void MetaInfo::Validate(int32_t device) const {
         "doesn't equal to actual number of rows given by data.";
   }
   auto check_device = [device](HostDeviceVector<float> const& v) {
-    CHECK(v.DeviceIdx() == GenericParameter::kCpuId ||
-          device == GenericParameter::kCpuId ||
-          v.DeviceIdx() == device)
+    CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
         << "Data is resided on a different device than `gpu_id`. "
         << "Device that data is on: " << v.DeviceIdx() << ", "
         << "`gpu_id` for XGBoost: " << device;
@@ -769,7 +768,9 @@ void MetaInfo::Validate(int32_t device) const {
 }

 #if !defined(XGBOOST_USE_CUDA)
-void MetaInfo::SetInfoFromCUDA(StringView key, Json arr) { common::AssertGPUSupport(); }
+void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json arr) {
+  common::AssertGPUSupport();
+}
 #endif  // !defined(XGBOOST_USE_CUDA)

 using DMatrixThreadLocal =

View File

@@ -115,7 +115,8 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
   }
 }  // namespace

-void MetaInfo::SetInfoFromCUDA(StringView key, Json array) {
+// Context is not used until we have CUDA stream.
+void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
   // multi-dim float info
   if (key == "base_margin") {
     CopyTensorInfoImpl(array, &base_margin_);

View File

@@ -43,18 +43,18 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
   size_t batches = 0;
   size_t accumulated_rows = 0;
   bst_feature_t cols = 0;
-  int32_t device = GenericParameter::kCpuId;
   int32_t current_device;
   dh::safe_cuda(cudaGetDevice(&current_device));
   auto get_device = [&]() -> int32_t {
-    int32_t d = (device == GenericParameter::kCpuId) ? current_device : device;
-    CHECK_NE(d, GenericParameter::kCpuId);
+    int32_t d = (ctx_.gpu_id == Context::kCpuId) ? current_device : ctx_.gpu_id;
+    CHECK_NE(d, Context::kCpuId);
     return d;
   };

   while (iter.Next()) {
-    device = proxy->DeviceIdx();
-    CHECK_LT(device, common::AllVisibleGPUs());
+    ctx_.gpu_id = proxy->DeviceIdx();
+    CHECK_LT(ctx_.gpu_id, common::AllVisibleGPUs());
     dh::safe_cuda(cudaSetDevice(get_device()));
     if (cols == 0) {
       cols = num_cols();

View File

@@ -21,6 +21,7 @@ namespace data {
 class IterativeDeviceDMatrix : public DMatrix {
   MetaInfo info_;
+  Context ctx_;
   BatchParam batch_param_;
   std::shared_ptr<EllpackPage> page_;
@@ -72,10 +73,7 @@ class IterativeDeviceDMatrix : public DMatrix {
   MetaInfo &Info() override { return info_; }
   MetaInfo const &Info() const override { return info_; }

-  GenericParameter const *Ctx() const override {
-    LOG(FATAL) << "`IterativeDMatrix` doesn't have context.";
-    return nullptr;
-  }
+  Context const *Ctx() const override { return &ctx_; }
 };

 #if !defined(XGBOOST_USE_CUDA)

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2020 XGBoost contributors
+ * Copyright 2020-2022, XGBoost contributors
 */
 #include "proxy_dmatrix.h"
 #include "device_adapter.cuh"
@@ -11,10 +11,10 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
   std::shared_ptr<data::CudfAdapter> adapter {new data::CudfAdapter(interface_str)};
   auto const& value = adapter->Value();
   this->batch_ = adapter;
-  device_ = adapter->DeviceIdx();
+  ctx_.gpu_id = adapter->DeviceIdx();
   this->Info().num_col_ = adapter->NumColumns();
   this->Info().num_row_ = adapter->NumRows();
-  if (device_ < 0) {
+  if (ctx_.gpu_id < 0) {
     CHECK_EQ(this->Info().num_row_, 0);
   }
 }
@@ -22,13 +22,12 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
 void DMatrixProxy::FromCudaArray(std::string interface_str) {
   std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
   this->batch_ = adapter;
-  device_ = adapter->DeviceIdx();
+  ctx_.gpu_id = adapter->DeviceIdx();
   this->Info().num_col_ = adapter->NumColumns();
   this->Info().num_row_ = adapter->NumRows();
-  if (device_ < 0) {
+  if (ctx_.gpu_id < 0) {
     CHECK_EQ(this->Info().num_row_, 0);
   }
 }
 }  // namespace data
 }  // namespace xgboost

View File

@@ -1,5 +1,5 @@
 /*!
- * Copyright 2020-2021 XGBoost contributors
+ * Copyright 2020-2022, XGBoost contributors
 */
 #ifndef XGBOOST_DATA_PROXY_DMATRIX_H_
 #define XGBOOST_DATA_PROXY_DMATRIX_H_
@@ -45,7 +45,7 @@ class DataIterProxy {
 class DMatrixProxy : public DMatrix {
   MetaInfo info_;
   dmlc::any batch_;
-  int32_t device_ { xgboost::GenericParameter::kCpuId };
+  Context ctx_;

 #if defined(XGBOOST_USE_CUDA)
   void FromCudaColumnar(std::string interface_str);
@@ -53,7 +53,7 @@ class DMatrixProxy : public DMatrix {
 #endif  // defined(XGBOOST_USE_CUDA)

  public:
-  int DeviceIdx() const { return device_; }
+  int DeviceIdx() const { return ctx_.gpu_id; }

   void SetData(char const* c_interface) {
     common::AssertGPUSupport();
@@ -67,7 +67,7 @@ class DMatrixProxy : public DMatrix {
       this->FromCudaArray(interface_str);
     }
     if (this->info_.num_row_ == 0) {
-      this->device_ = GenericParameter::kCpuId;
+      this->ctx_.gpu_id = Context::kCpuId;
     }
 #endif  // defined(XGBOOST_USE_CUDA)
   }
@@ -79,10 +79,7 @@ class DMatrixProxy : public DMatrix {
   MetaInfo& Info() override { return info_; }
   MetaInfo const& Info() const override { return info_; }

-  GenericParameter const* Ctx() const override {
-    LOG(FATAL) << "`ProxyDMatrix` doesn't have context.";
-    return nullptr;
-  }
+  Context const* Ctx() const override { return &ctx_; }

   bool SingleColBlock() const override { return true; }
   bool EllpackExists() const override { return true; }

View File

@@ -149,10 +149,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
       weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size());
     }
     if (batch.BaseMargin() != nullptr) {
-      info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(),
-                                                        batch.BaseMargin() + batch.Size(),
-                                                        {batch.Size()},
-                                                        GenericParameter::kCpuId};
+      info_.base_margin_ = decltype(info_.base_margin_){
+          batch.BaseMargin(), batch.BaseMargin() + batch.Size(), {batch.Size()}, Context::kCpuId};
     }
     if (batch.Qid() != nullptr) {
       qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());

View File

@@ -31,7 +31,7 @@ class SimpleDMatrix : public DMatrix {
   MetaInfo& Info() override;
   const MetaInfo& Info() const override;
-  GenericParameter const* Ctx() const override { return &ctx_; }
+  Context const* Ctx() const override { return &ctx_; }
   bool SingleColBlock() const override { return true; }
   DMatrix* Slice(common::Span<int32_t const> ridxs) override;
@@ -63,7 +63,7 @@ class SimpleDMatrix : public DMatrix {
   }

  private:
-  GenericParameter ctx_;
+  Context ctx_;
 };
 }  // namespace data
 }  // namespace xgboost

View File

@@ -69,7 +69,7 @@ class SparsePageDMatrix : public DMatrix {
   XGDMatrixCallbackNext *next_;
   float missing_;
-  GenericParameter ctx_;
+  Context ctx_;
   std::string cache_prefix_;
   uint32_t n_batches_ {0};
   // sparse page is the source to other page types, we make a special member function.
@@ -100,7 +100,7 @@ class SparsePageDMatrix : public DMatrix {
   MetaInfo& Info() override;
   const MetaInfo& Info() const override;
-  GenericParameter const* Ctx() const override { return &ctx_; }
+  Context const* Ctx() const override { return &ctx_; }
   bool SingleColBlock() const override { return false; }
   DMatrix *Slice(common::Span<int32_t const>) override {

View File

@@ -149,8 +149,7 @@ TEST(CutsBuilder, SearchGroupInd) {
   group[2] = 7;
   group[3] = 5;

-  p_mat->Info().SetInfo(
-      "group", group.data(), DataType::kUInt32, kNumGroups);
+  p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups);

   HistogramCuts hmat;
@@ -350,6 +349,7 @@ void TestSketchFromWeights(bool with_group) {
   common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));

   MetaInfo info;
+  Context ctx;
   auto& h_weights = info.weights_.HostVector();
   if (with_group) {
     h_weights.resize(kGroups);
@@ -363,7 +363,7 @@ void TestSketchFromWeights(bool with_group) {
     for (size_t i = 0; i < kGroups; ++i) {
       groups[i] = kRows / kGroups;
     }
-    info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
+    info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
   }

   info.num_row_ = kRows;
@@ -371,10 +371,10 @@ void TestSketchFromWeights(bool with_group) {
   // Assign weights.
   if (with_group) {
-    m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
+    m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
   }
-  m->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
+  m->SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
   m->Info().num_col_ = kCols;
   m->Info().num_row_ = kRows;
   ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);

View File

@@ -520,7 +520,7 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) {
   for (size_t i = 0; i < kGroups; ++i) {
     groups[i] = kRows / kGroups;
   }
-  m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
+  m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
   HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0);

   h_weights.clear();
@@ -550,6 +550,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
       RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(
           &storage);
   MetaInfo info;
+  Context ctx;
   auto& h_weights = info.weights_.HostVector();
   if (with_group) {
     h_weights.resize(kGroups);
@@ -563,7 +564,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
     for (size_t i = 0; i < kGroups; ++i) {
       groups[i] = kRows / kGroups;
     }
-    info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
+    info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
   }

   info.weights_.SetDevice(0);
@@ -582,10 +583,10 @@ void TestAdapterSketchFromWeights(bool with_group) {
     auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
     if (with_group) {
-      dmat->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
+      dmat->Info().SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
     }

-    dmat->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
+    dmat->Info().SetInfo(ctx, "weight", h_weights.data(), DataType::kFloat32, h_weights.size());
     dmat->Info().num_col_ = kCols;
     dmat->Info().num_row_ = kRows;
     ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);

View File

@@ -12,28 +12,29 @@
 #include "xgboost/base.h"

 TEST(MetaInfo, GetSet) {
+  xgboost::Context ctx;
   xgboost::MetaInfo info;

   double double2[2] = {1.0, 2.0};

   EXPECT_EQ(info.labels.Size(), 0);
-  info.SetInfo("label", double2, xgboost::DataType::kFloat32, 2);
+  info.SetInfo(ctx, "label", double2, xgboost::DataType::kFloat32, 2);
   EXPECT_EQ(info.labels.Size(), 2);

   float float2[2] = {1.0f, 2.0f};
   EXPECT_EQ(info.GetWeight(1), 1.0f)
       << "When no weights are given, was expecting default value 1";
-  info.SetInfo("weight", float2, xgboost::DataType::kFloat32, 2);
+  info.SetInfo(ctx, "weight", float2, xgboost::DataType::kFloat32, 2);
   EXPECT_EQ(info.GetWeight(1), 2.0f);

   uint32_t uint32_t2[2] = {1U, 2U};
   EXPECT_EQ(info.base_margin_.Size(), 0);
-  info.SetInfo("base_margin", uint32_t2, xgboost::DataType::kUInt32, 2);
+  info.SetInfo(ctx, "base_margin", uint32_t2, xgboost::DataType::kUInt32, 2);
   EXPECT_EQ(info.base_margin_.Size(), 2);

   uint64_t uint64_t2[2] = {1U, 2U};
   EXPECT_EQ(info.group_ptr_.size(), 0);
-  info.SetInfo("group", uint64_t2, xgboost::DataType::kUInt64, 2);
+  info.SetInfo(ctx, "group", uint64_t2, xgboost::DataType::kUInt64, 2);
   ASSERT_EQ(info.group_ptr_.size(), 3);
   EXPECT_EQ(info.group_ptr_[2], 3);
@@ -73,6 +74,8 @@ TEST(MetaInfo, GetSetFeature) {
 TEST(MetaInfo, SaveLoadBinary) {
   xgboost::MetaInfo info;
+  xgboost::Context ctx;
+
   uint64_t constexpr kRows { 64 }, kCols { 32 };
   auto generator = []() {
     static float f = 0;
@@ -80,9 +83,9 @@ TEST(MetaInfo, SaveLoadBinary) {
   };
   std::vector<float> values (kRows);
   std::generate(values.begin(), values.end(), generator);
-  info.SetInfo("label", values.data(), xgboost::DataType::kFloat32, kRows);
-  info.SetInfo("weight", values.data(), xgboost::DataType::kFloat32, kRows);
-  info.SetInfo("base_margin", values.data(), xgboost::DataType::kFloat32, kRows);
+  info.SetInfo(ctx, "label", values.data(), xgboost::DataType::kFloat32, kRows);
+  info.SetInfo(ctx, "weight", values.data(), xgboost::DataType::kFloat32, kRows);
+  info.SetInfo(ctx, "base_margin", values.data(), xgboost::DataType::kFloat32, kRows);
   info.num_row_ = kRows;
   info.num_col_ = kCols;
@@ -210,13 +213,14 @@ TEST(MetaInfo, LoadQid) {
 TEST(MetaInfo, CPUQid) {
   xgboost::MetaInfo info;
+  xgboost::Context ctx;
   info.num_row_ = 100;
   std::vector<uint32_t> qid(info.num_row_, 0);
   for (size_t i = 0; i < qid.size(); ++i) {
     qid[i] = i;
   }

-  info.SetInfo("qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_);
+  info.SetInfo(ctx, "qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_);
   ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
   ASSERT_EQ(info.group_ptr_.front(), 0);
   ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
@@ -232,12 +236,15 @@ TEST(MetaInfo, Validate) {
   info.num_nonzero_ = 12;
   info.num_col_ = 3;
   std::vector<xgboost::bst_group_t> groups (11);
-  info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, 11);
+  xgboost::Context ctx;
+  info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, 11);
   EXPECT_THROW(info.Validate(0), dmlc::Error);

   std::vector<float> labels(info.num_row_ + 1);
   EXPECT_THROW(
-      { info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); },
+      {
+        info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1);
+      },
       dmlc::Error);

   // Make overflow data, which can happen when users pass group structure as int
@@ -247,14 +254,13 @@ TEST(MetaInfo, Validate) {
     groups.push_back(1562500);
   }
   groups.push_back(static_cast<xgboost::bst_group_t>(-1));
-  EXPECT_THROW(info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32,
-                            groups.size()),
+  EXPECT_THROW(info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()),
               dmlc::Error);

 #if defined(XGBOOST_USE_CUDA)
   info.group_ptr_.clear();
   labels.resize(info.num_row_);
-  info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
+  info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
   info.labels.SetDevice(0);
   EXPECT_THROW(info.Validate(1), dmlc::Error);
@@ -263,12 +269,13 @@ TEST(MetaInfo, Validate) {
   d_groups.DevicePointer();  // pull to device
   std::string arr_interface_str{ArrayInterfaceStr(
       xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0))};
-  EXPECT_THROW(info.SetInfo("group", xgboost::StringView{arr_interface_str}), dmlc::Error);
+  EXPECT_THROW(info.SetInfo(ctx, "group", xgboost::StringView{arr_interface_str}), dmlc::Error);
 #endif  // defined(XGBOOST_USE_CUDA)
 }

 TEST(MetaInfo, HostExtend) {
   xgboost::MetaInfo lhs, rhs;
+  xgboost::Context ctx;
   size_t const kRows = 100;
   lhs.labels.Reshape(kRows);
   lhs.num_row_ = kRows;
@@ -282,8 +289,8 @@ TEST(MetaInfo, HostExtend) {
   for (size_t g = 0; g < kRows / per_group; ++g) {
     groups.emplace_back(per_group);
   }
-  lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
-  rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
+  lhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size());
+  rhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size());

   lhs.Extend(rhs, true, true);
   ASSERT_EQ(lhs.num_row_, kRows * 2);
@@ -300,5 +307,5 @@ TEST(MetaInfo, HostExtend) {
 }

 namespace xgboost {
-TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(GenericParameter::kCpuId); }
+TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(Context::kCpuId); }
 }  // namespace xgboost

View File

@@ -25,14 +25,13 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons
   std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
   column["shape"] = Array(j_shape);
-  column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(sizeof(T))))});
+  column["strides"] = Array(std::vector<Json>{Json(Integer{static_cast<Integer::Int>(sizeof(T))})});
   column["version"] = 3;
   column["typestr"] = String(typestr);

   auto p_d_data = d_data.data().get();
-  std::vector<Json> j_data {
-    Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
-    Json(Boolean(false))};
+  std::vector<Json> j_data{Json(Integer{reinterpret_cast<Integer::Int>(p_d_data)}),
+                           Json(Boolean(false))};
   column["data"] = j_data;
   column["stream"] = nullptr;
   Json array(std::vector<Json>{column});
@@ -45,12 +44,13 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons
 TEST(MetaInfo, FromInterface) {
   cudaSetDevice(0);
+  Context ctx;
   thrust::device_vector<float> d_data;
   std::string str = PrepareData<float>("<f4", &d_data);

   MetaInfo info;
-  info.SetInfo("label", str.c_str());
+  info.SetInfo(ctx, "label", str.c_str());

   auto const& h_label = info.labels.HostView();
   ASSERT_EQ(h_label.Size(), d_data.size());
@@ -58,13 +58,13 @@ TEST(MetaInfo, FromInterface) {
     ASSERT_EQ(h_label(i), d_data[i]);
   }

-  info.SetInfo("weight", str.c_str());
+  info.SetInfo(ctx, "weight", str.c_str());
   auto const& h_weight = info.weights_.HostVector();
   for (size_t i = 0; i < d_data.size(); ++i) {
     ASSERT_EQ(h_weight[i], d_data[i]);
   }

-  info.SetInfo("base_margin", str.c_str());
+  info.SetInfo(ctx, "base_margin", str.c_str());
   auto const h_base_margin = info.base_margin_.View(GenericParameter::kCpuId);
   ASSERT_EQ(h_base_margin.Size(), d_data.size());
   for (size_t i = 0; i < d_data.size(); ++i) {
@@ -77,7 +77,7 @@ TEST(MetaInfo, FromInterface) {
   d_group_data[1] = 3;
   d_group_data[2] = 2;
   d_group_data[3] = 1;
-  info.SetInfo("group", group_str.c_str());
+  info.SetInfo(ctx, "group", group_str.c_str());
   std::vector<bst_group_t> expected_group_ptr = {0, 4, 7, 9, 10};
   EXPECT_EQ(info.group_ptr_, expected_group_ptr);
 }
@@ -89,10 +89,11 @@ TEST(MetaInfo, GPUStridedData) {
 TEST(MetaInfo, Group) {
   cudaSetDevice(0);
   MetaInfo info;
+  Context ctx;

   thrust::device_vector<uint32_t> d_uint;
   std::string uint_str = PrepareData<uint32_t>("<u4", &d_uint);
-  info.SetInfo("group", uint_str.c_str());
+  info.SetInfo(ctx, "group", uint_str.c_str());
   auto& h_group = info.group_ptr_;
   ASSERT_EQ(h_group.size(), d_uint.size() + 1);
   for (size_t i = 1; i < h_group.size(); ++i) {
@@ -102,7 +103,7 @@ TEST(MetaInfo, Group) {
   thrust::device_vector<int64_t> d_int64;
   std::string int_str = PrepareData<int64_t>("<i8", &d_int64);
   info = MetaInfo();
-  info.SetInfo("group", int_str.c_str());
+  info.SetInfo(ctx, "group", int_str.c_str());
   h_group = info.group_ptr_;
   ASSERT_EQ(h_group.size(), d_uint.size() + 1);
   for (size_t i = 1; i < h_group.size(); ++i) {
@@ -113,11 +114,12 @@ TEST(MetaInfo, Group) {
   thrust::device_vector<float> d_float;
   std::string float_str = PrepareData<float>("<f4", &d_float);
   info = MetaInfo();
-  EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
+  EXPECT_ANY_THROW(info.SetInfo(ctx, "group", float_str.c_str()));
 }

 TEST(MetaInfo, GPUQid) {
   xgboost::MetaInfo info;
+  Context ctx;
   info.num_row_ = 100;
   thrust::device_vector<uint32_t> qid(info.num_row_, 0);
   for (size_t i = 0; i < qid.size(); ++i) {
@@ -127,7 +129,7 @@ TEST(MetaInfo, GPUQid) {
   Json array{std::vector<Json>{column}};
   std::string array_str;
   Json::Dump(array, &array_str);
-  info.SetInfo("qid", array_str.c_str());
+  info.SetInfo(ctx, "qid", array_str.c_str());
   ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
   ASSERT_EQ(info.group_ptr_.front(), 0);
   ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
@@ -142,11 +144,12 @@ TEST(MetaInfo, DeviceExtend) {
   dh::safe_cuda(cudaSetDevice(0));
   size_t const kRows = 100;
   MetaInfo lhs, rhs;
+  Context ctx;
   thrust::device_vector<float> d_data;

   std::string str = PrepareData<float>("<f4", &d_data, kRows);
-  lhs.SetInfo("label", str.c_str());
-  rhs.SetInfo("label", str.c_str());
+  lhs.SetInfo(ctx, "label", str.c_str());
+  rhs.SetInfo(ctx, "label", str.c_str());
   ASSERT_FALSE(rhs.labels.Data()->HostCanRead());
   lhs.num_row_ = kRows;
   rhs.num_row_ = kRows;

View File

@@ -16,6 +16,8 @@
 namespace xgboost {
 inline void TestMetaInfoStridedData(int32_t device) {
   MetaInfo info;
+  Context ctx;
+  ctx.UpdateAllowUnknown(Args{{"gpu_id", std::to_string(device)}});
   {
     // labels
     linalg::Tensor<float, 3> labels;
@@ -25,7 +27,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
     auto t_labels = labels.View(device).Slice(linalg::All(), 0, linalg::All());
     ASSERT_EQ(t_labels.Shape().size(), 2);

-    info.SetInfo("label", StringView{ArrayInterfaceStr(t_labels)});
+    info.SetInfo(ctx, "label", StringView{ArrayInterfaceStr(t_labels)});
     auto const& h_result = info.labels.View(-1);
     ASSERT_EQ(h_result.Shape().size(), 2);
     auto in_labels = labels.View(-1);
@@ -46,7 +48,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
     std::iota(h_qid.begin(), h_qid.end(), 0);
     auto s = qid.View(device).Slice(linalg::All(), 0);
     auto str = ArrayInterfaceStr(s);
-    info.SetInfo("qid", StringView{str});
+    info.SetInfo(ctx, "qid", StringView{str});
     auto const& h_result = info.group_ptr_;
     ASSERT_EQ(h_result.size(), s.Size() + 1);
   }
@@ -59,7 +61,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
     auto t_margin = base_margin.View(device).Slice(linalg::All(), 0, linalg::All());
     ASSERT_EQ(t_margin.Shape().size(), 2);

-    info.SetInfo("base_margin", StringView{ArrayInterfaceStr(t_margin)});
+    info.SetInfo(ctx, "base_margin", StringView{ArrayInterfaceStr(t_margin)});
     auto const& h_result = info.base_margin_.View(-1);
     ASSERT_EQ(h_result.Shape().size(), 2);
     auto in_margin = base_margin.View(-1);

View File

@@ -257,7 +257,7 @@ TEST(Dart, Prediction) {
   for (size_t i = 0; i < kRows; ++i) {
     labels[i] = i % 2;
   }
-  p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kRows);
+  p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows);

   auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
   learner->SetParam("booster", "dart");

View File

@@ -74,11 +74,9 @@ TEST(Learner, CheckGroup) {
     labels[i] = i % 2;
   }

-  p_mat->Info().SetInfo(
-      "weight", static_cast<void*>(weight.data()), DataType::kFloat32, kNumGroups);
-  p_mat->Info().SetInfo(
-      "group", group.data(), DataType::kUInt32, kNumGroups);
-  p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kNumRows);
+  p_mat->SetInfo("weight", static_cast<void *>(weight.data()), DataType::kFloat32, kNumGroups);
+  p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups);
+  p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kNumRows);

   std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {p_mat};
   auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
@@ -88,7 +86,7 @@ TEST(Learner, CheckGroup) {
   group.resize(kNumGroups+1);
   group[3] = 4;
   group[4] = 1;
-  p_mat->Info().SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
+  p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
   EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat));
 }

@@ -105,7 +103,7 @@ TEST(Learner, SLOW_CheckMultiBatch) {  // NOLINT
   for (size_t i = 0; i < num_row; ++i) {
     labels[i] = i % 2;
   }
-  dmat->Info().SetInfo("label", labels.data(), DataType::kFloat32, num_row);
+  dmat->SetInfo("label", labels.data(), DataType::kFloat32, num_row);
   std::vector<std::shared_ptr<DMatrix>> mat{dmat};
   auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
   learner->SetParams(Args{{"objective", "binary:logistic"}});