Use context in SetInfo. (#7687)
* Use the name `Context`.
* Pass a context object into `SetInfo`.
* Add context to the proxy DMatrix.
* Add context to the iterative DMatrix.

This removes the use of the default number of threads during `SetInfo`, as a follow-up to removing the global omp variable, and prepares for CUDA stream semantics. Currently XGBoost uses the legacy CUDA stream; we will gradually replace it with non-blocking streams.
parent f5b20286e2, commit 64575591d8
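The user-visible change is that `MetaInfo::SetInfo` (and the `DMatrix::SetInfo` forwarders) now take a `Context` as the first argument, so the copy inside `SetInfo` runs with the thread count configured on the context instead of the global OpenMP default. A minimal sketch of the new call pattern, modeled on the updated tests; the `nthread` value and the label data below are illustrative only:

#include "xgboost/data.h"                // MetaInfo, DataType
#include "xgboost/generic_parameters.h"  // Context (alias of GenericParameter)

// Sketch: build a Context and pass it to SetInfo so the meta info copy
// uses ctx.Threads() rather than a default thread count.
void SetLabelExample() {
  xgboost::Context ctx;
  ctx.UpdateAllowUnknown(xgboost::Args{{"nthread", "4"}});  // illustrative thread count

  xgboost::MetaInfo info;
  float labels[2] = {0.0f, 1.0f};
  // New signature: context first, then key, data pointer, type, and length.
  info.SetInfo(ctx, "label", labels, xgboost::DataType::kFloat32, 2);
}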
@@ -148,13 +148,13 @@ class MetaInfo {
   * \param dtype The type of the source data.
   * \param num Number of elements in the source array.
   */
  void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
  void SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype, size_t num);
  /*!
   * \brief Set information in the meta info with array interface.
   * \param key The key of the information.
   * \param interface_str String representation of json format array interface.
   */
  void SetInfo(StringView key, StringView interface_str);
  void SetInfo(Context const& ctx, StringView key, StringView interface_str);

  void GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
               const void** out_dptr) const;
@@ -176,8 +176,8 @@ class MetaInfo {
  void Extend(MetaInfo const& that, bool accumulate_rows, bool check_column);

 private:
  void SetInfoFromHost(StringView key, Json arr);
  void SetInfoFromCUDA(StringView key, Json arr);
  void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
  void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);

  /*! \brief argsort of labels */
  mutable std::vector<size_t> label_order_cache_;
@@ -478,12 +478,13 @@ class DMatrix {
  DMatrix() = default;
  /*! \brief meta information of the dataset */
  virtual MetaInfo& Info() = 0;
  virtual void SetInfo(const char *key, const void *dptr, DataType dtype,
                       size_t num) {
    this->Info().SetInfo(key, dptr, dtype, num);
  virtual void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
    auto const& ctx = *this->Ctx();
    this->Info().SetInfo(ctx, key, dptr, dtype, num);
  }
  virtual void SetInfo(const char* key, std::string const& interface_str) {
    this->Info().SetInfo(key, StringView{interface_str});
    auto const& ctx = *this->Ctx();
    this->Info().SetInfo(ctx, key, StringView{interface_str});
  }
  /*! \brief meta information of the dataset */
  virtual const MetaInfo& Info() const = 0;
@@ -494,7 +495,7 @@ class DMatrix {
   * \brief Get the context object of this DMatrix. The context is created during construction of
   * DMatrix with user specified `nthread` parameter.
   */
  virtual GenericParameter const* Ctx() const = 0;
  virtual Context const* Ctx() const = 0;

  /**
   * \brief Gets batches. Use range based for loop over BatchSet to access individual batches.

@@ -75,6 +75,8 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
        .describe("Enable checking whether parameters are used or not.");
  }
};

using Context = GenericParameter;
}  // namespace xgboost

#endif  // XGBOOST_GENERIC_PARAMETERS_H_

@@ -485,35 +485,30 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname,
  API_END();
}

XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
                                  const char* field,
                                  const bst_float* info,
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const bst_float *info,
                                  xgboost::bst_ulong len) {
  API_BEGIN();
  CHECK_HANDLE();
  static_cast<std::shared_ptr<DMatrix>*>(handle)
      ->get()->Info().SetInfo(field, info, xgboost::DataType::kFloat32, len);
  auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
  p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
  API_END();
}

XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
                                          char const* field,
                                          char const* interface_c_str) {
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, char const *field,
                                          char const *interface_c_str) {
  API_BEGIN();
  CHECK_HANDLE();
  static_cast<std::shared_ptr<DMatrix>*>(handle)
      ->get()->Info().SetInfo(field, interface_c_str);
  auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
  p_fmat->SetInfo(field, interface_c_str);
  API_END();
}

XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
                                 const char* field,
                                 const unsigned* info,
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const unsigned *info,
                                 xgboost::bst_ulong len) {
  API_BEGIN();
  CHECK_HANDLE();
  static_cast<std::shared_ptr<DMatrix>*>(handle)
      ->get()->Info().SetInfo(field, info, xgboost::DataType::kUInt32, len);
  auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
  p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
  API_END();
}

@@ -549,25 +544,22 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
  API_END();
}

XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
                                  void const *data, xgboost::bst_ulong size,
                                  int type) {
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void const *data,
                                  xgboost::bst_ulong size, int type) {
  API_BEGIN();
  CHECK_HANDLE();
  auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
  auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
  CHECK(type >= 1 && type <= 4);
  info.SetInfo(field, data, static_cast<DataType>(type), size);
  p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
  API_END();
}

XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
                              const unsigned* group,
                              xgboost::bst_ulong len) {
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle, const unsigned *group, xgboost::bst_ulong len) {
  API_BEGIN();
  CHECK_HANDLE();
  LOG(WARNING) << "XGDMatrixSetGroup is deprecated, use `XGDMatrixSetUIntInfo` instead.";
  static_cast<std::shared_ptr<DMatrix>*>(handle)
      ->get()->Info().SetInfo("group", group, xgboost::DataType::kUInt32, len);
  auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
  p_fmat->SetInfo("group", group, xgboost::DataType::kUInt32, len);
  API_END();
}

@@ -409,7 +409,7 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,

namespace {
template <int32_t D, typename T>
void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
  ArrayInterface<D> array{arr_interface};
  if (array.n == 0) {
    p_out->Reshape(array.shape);
@@ -428,16 +428,15 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
    return;
  }
  p_out->Reshape(array.shape);
  auto t = p_out->View(GenericParameter::kCpuId);
  auto t = p_out->View(Context::kCpuId);
  CHECK(t.CContiguous());
  // FIXME(jiamingy): Remove the use of this default thread.
  linalg::ElementWiseTransformHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) {
  linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
    return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
  });
}
}  // namespace

void MetaInfo::SetInfo(StringView key, StringView interface_str) {
void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_str) {
  Json j_interface = Json::Load(interface_str);
  bool is_cuda{false};
  if (IsA<Array>(j_interface)) {
@@ -454,16 +453,16 @@ void MetaInfo::SetInfo(StringView key, StringView interface_str) {
  }

  if (is_cuda) {
    this->SetInfoFromCUDA(key, j_interface);
    this->SetInfoFromCUDA(ctx, key, j_interface);
  } else {
    this->SetInfoFromHost(key, j_interface);
    this->SetInfoFromHost(ctx, key, j_interface);
  }
}

void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
  // multi-dim float info
  if (key == "base_margin") {
    CopyTensorInfoImpl(arr, &this->base_margin_);
    CopyTensorInfoImpl(ctx, arr, &this->base_margin_);
    // FIXME(jiamingy): Remove the deprecated API and let all language bindings aware of
    // input shape. This issue is CPU only since CUDA uses array interface from day 1.
    //
@@ -477,7 +476,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
    }
    return;
  } else if (key == "label") {
    CopyTensorInfoImpl(arr, &this->labels);
    CopyTensorInfoImpl(ctx, arr, &this->labels);
    if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) {
      CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels.";
      size_t n_targets = this->labels.Size() / this->num_row_;
@@ -491,7 +490,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
  // uint info
  if (key == "group") {
    linalg::Tensor<bst_group_t, 1> t;
    CopyTensorInfoImpl(arr, &t);
    CopyTensorInfoImpl(ctx, arr, &t);
    auto const& h_groups = t.Data()->HostVector();
    group_ptr_.clear();
    group_ptr_.resize(h_groups.size() + 1, 0);
@@ -501,7 +500,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
    return;
  } else if (key == "qid") {
    linalg::Tensor<bst_group_t, 1> t;
    CopyTensorInfoImpl(arr, &t);
    CopyTensorInfoImpl(ctx, arr, &t);
    bool non_dec = true;
    auto const& query_ids = t.Data()->HostVector();
    for (size_t i = 1; i < query_ids.size(); ++i) {
@@ -526,7 +525,7 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
  }
  // float info
  linalg::Tensor<float, 1> t;
  CopyTensorInfoImpl<1>(arr, &t);
  CopyTensorInfoImpl<1>(ctx, arr, &t);
  if (key == "weight") {
    this->weights_ = std::move(*t.Data());
    auto const& h_weights = this->weights_.ConstHostVector();
@@ -548,13 +547,15 @@ void MetaInfo::SetInfoFromHost(StringView key, Json arr) {
  }
}

void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
                       size_t num) {
  auto proc = [&](auto cast_d_ptr) {
    using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
    auto t =
        linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, GenericParameter::kCpuId);
    auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
    CHECK(t.CContiguous());
    Json interface { linalg::ArrayInterface(t) };
    Json interface {
      linalg::ArrayInterface(t)
    };
    assert(ArrayInterface<1>{interface}.is_contiguous);
    return interface;
  };
@@ -562,22 +563,22 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
  switch (dtype) {
    case xgboost::DataType::kFloat32: {
      auto cast_ptr = reinterpret_cast<const float*>(dptr);
      this->SetInfoFromHost(key, proc(cast_ptr));
      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
      break;
    }
    case xgboost::DataType::kDouble: {
      auto cast_ptr = reinterpret_cast<const double*>(dptr);
      this->SetInfoFromHost(key, proc(cast_ptr));
      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
      break;
    }
    case xgboost::DataType::kUInt32: {
      auto cast_ptr = reinterpret_cast<const uint32_t*>(dptr);
      this->SetInfoFromHost(key, proc(cast_ptr));
      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
      break;
    }
    case xgboost::DataType::kUInt64: {
      auto cast_ptr = reinterpret_cast<const uint64_t*>(dptr);
      this->SetInfoFromHost(key, proc(cast_ptr));
      this->SetInfoFromHost(ctx, key, proc(cast_ptr));
      break;
    }
    default:
@@ -724,9 +725,7 @@ void MetaInfo::Validate(int32_t device) const {
        "doesn't equal to actual number of rows given by data.";
  }
  auto check_device = [device](HostDeviceVector<float> const& v) {
    CHECK(v.DeviceIdx() == GenericParameter::kCpuId ||
          device == GenericParameter::kCpuId ||
          v.DeviceIdx() == device)
    CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
        << "Data is resided on a different device than `gpu_id`. "
        << "Device that data is on: " << v.DeviceIdx() << ", "
        << "`gpu_id` for XGBoost: " << device;
@@ -769,7 +768,9 @@ void MetaInfo::Validate(int32_t device) const {
}

#if !defined(XGBOOST_USE_CUDA)
void MetaInfo::SetInfoFromCUDA(StringView key, Json arr) { common::AssertGPUSupport(); }
void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json arr) {
  common::AssertGPUSupport();
}
#endif  // !defined(XGBOOST_USE_CUDA)

using DMatrixThreadLocal =

@@ -115,7 +115,8 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
}
}  // namespace

void MetaInfo::SetInfoFromCUDA(StringView key, Json array) {
// Context is not used until we have CUDA stream.
void MetaInfo::SetInfoFromCUDA(Context const&, StringView key, Json array) {
  // multi-dim float info
  if (key == "base_margin") {
    CopyTensorInfoImpl(array, &base_margin_);

@@ -43,18 +43,18 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
  size_t batches = 0;
  size_t accumulated_rows = 0;
  bst_feature_t cols = 0;
  int32_t device = GenericParameter::kCpuId;

  int32_t current_device;
  dh::safe_cuda(cudaGetDevice(&current_device));
  auto get_device = [&]() -> int32_t {
    int32_t d = (device == GenericParameter::kCpuId) ? current_device : device;
    CHECK_NE(d, GenericParameter::kCpuId);
    int32_t d = (ctx_.gpu_id == Context::kCpuId) ? current_device : ctx_.gpu_id;
    CHECK_NE(d, Context::kCpuId);
    return d;
  };

  while (iter.Next()) {
    device = proxy->DeviceIdx();
    CHECK_LT(device, common::AllVisibleGPUs());
    ctx_.gpu_id = proxy->DeviceIdx();
    CHECK_LT(ctx_.gpu_id, common::AllVisibleGPUs());
    dh::safe_cuda(cudaSetDevice(get_device()));
    if (cols == 0) {
      cols = num_cols();

@@ -21,6 +21,7 @@ namespace data {

class IterativeDeviceDMatrix : public DMatrix {
  MetaInfo info_;
  Context ctx_;
  BatchParam batch_param_;
  std::shared_ptr<EllpackPage> page_;

@@ -72,10 +73,7 @@ class IterativeDeviceDMatrix : public DMatrix {
  MetaInfo &Info() override { return info_; }
  MetaInfo const &Info() const override { return info_; }

  GenericParameter const *Ctx() const override {
    LOG(FATAL) << "`IterativeDMatrix` doesn't have context.";
    return nullptr;
  }
  Context const *Ctx() const override { return &ctx_; }
};

#if !defined(XGBOOST_USE_CUDA)

@@ -1,5 +1,5 @@
/*!
 * Copyright 2020 XGBoost contributors
 * Copyright 2020-2022, XGBoost contributors
 */
#include "proxy_dmatrix.h"
#include "device_adapter.cuh"
@@ -11,10 +11,10 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
  std::shared_ptr<data::CudfAdapter> adapter {new data::CudfAdapter(interface_str)};
  auto const& value = adapter->Value();
  this->batch_ = adapter;
  device_ = adapter->DeviceIdx();
  ctx_.gpu_id = adapter->DeviceIdx();
  this->Info().num_col_ = adapter->NumColumns();
  this->Info().num_row_ = adapter->NumRows();
  if (device_ < 0) {
  if (ctx_.gpu_id < 0) {
    CHECK_EQ(this->Info().num_row_, 0);
  }
}
@@ -22,13 +22,12 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
void DMatrixProxy::FromCudaArray(std::string interface_str) {
  std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
  this->batch_ = adapter;
  device_ = adapter->DeviceIdx();
  ctx_.gpu_id = adapter->DeviceIdx();
  this->Info().num_col_ = adapter->NumColumns();
  this->Info().num_row_ = adapter->NumRows();
  if (device_ < 0) {
  if (ctx_.gpu_id < 0) {
    CHECK_EQ(this->Info().num_row_, 0);
  }
}

}  // namespace data
}  // namespace xgboost

@@ -1,5 +1,5 @@
/*!
 * Copyright 2020-2021 XGBoost contributors
 * Copyright 2020-2022, XGBoost contributors
 */
#ifndef XGBOOST_DATA_PROXY_DMATRIX_H_
#define XGBOOST_DATA_PROXY_DMATRIX_H_
@@ -45,7 +45,7 @@ class DataIterProxy {
class DMatrixProxy : public DMatrix {
  MetaInfo info_;
  dmlc::any batch_;
  int32_t device_ { xgboost::GenericParameter::kCpuId };
  Context ctx_;

#if defined(XGBOOST_USE_CUDA)
  void FromCudaColumnar(std::string interface_str);
@@ -53,7 +53,7 @@ class DMatrixProxy : public DMatrix {
#endif  // defined(XGBOOST_USE_CUDA)

 public:
  int DeviceIdx() const { return device_; }
  int DeviceIdx() const { return ctx_.gpu_id; }

  void SetData(char const* c_interface) {
    common::AssertGPUSupport();
@@ -67,7 +67,7 @@ class DMatrixProxy : public DMatrix {
      this->FromCudaArray(interface_str);
    }
    if (this->info_.num_row_ == 0) {
      this->device_ = GenericParameter::kCpuId;
      this->ctx_.gpu_id = Context::kCpuId;
    }
#endif  // defined(XGBOOST_USE_CUDA)
  }
@@ -79,10 +79,7 @@ class DMatrixProxy : public DMatrix {

  MetaInfo& Info() override { return info_; }
  MetaInfo const& Info() const override { return info_; }
  GenericParameter const* Ctx() const override {
    LOG(FATAL) << "`ProxyDMatrix` doesn't have context.";
    return nullptr;
  }
  Context const* Ctx() const override { return &ctx_; }

  bool SingleColBlock() const override { return true; }
  bool EllpackExists() const override { return true; }

@@ -149,10 +149,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
      weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size());
    }
    if (batch.BaseMargin() != nullptr) {
      info_.base_margin_ = decltype(info_.base_margin_){batch.BaseMargin(),
                                                        batch.BaseMargin() + batch.Size(),
                                                        {batch.Size()},
                                                        GenericParameter::kCpuId};
      info_.base_margin_ = decltype(info_.base_margin_){
          batch.BaseMargin(), batch.BaseMargin() + batch.Size(), {batch.Size()}, Context::kCpuId};
    }
    if (batch.Qid() != nullptr) {
      qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());

@@ -31,7 +31,7 @@ class SimpleDMatrix : public DMatrix {

  MetaInfo& Info() override;
  const MetaInfo& Info() const override;
  GenericParameter const* Ctx() const override { return &ctx_; }
  Context const* Ctx() const override { return &ctx_; }

  bool SingleColBlock() const override { return true; }
  DMatrix* Slice(common::Span<int32_t const> ridxs) override;
@@ -63,7 +63,7 @@ class SimpleDMatrix : public DMatrix {
  }

 private:
  GenericParameter ctx_;
  Context ctx_;
};
}  // namespace data
}  // namespace xgboost

@@ -69,7 +69,7 @@ class SparsePageDMatrix : public DMatrix {
  XGDMatrixCallbackNext *next_;

  float missing_;
  GenericParameter ctx_;
  Context ctx_;
  std::string cache_prefix_;
  uint32_t n_batches_ {0};
  // sparse page is the source to other page types, we make a special member function.
@@ -100,7 +100,7 @@ class SparsePageDMatrix : public DMatrix {

  MetaInfo& Info() override;
  const MetaInfo& Info() const override;
  GenericParameter const* Ctx() const override { return &ctx_; }
  Context const* Ctx() const override { return &ctx_; }

  bool SingleColBlock() const override { return false; }
  DMatrix *Slice(common::Span<int32_t const>) override {

@@ -149,8 +149,7 @@ TEST(CutsBuilder, SearchGroupInd) {
  group[2] = 7;
  group[3] = 5;

  p_mat->Info().SetInfo(
      "group", group.data(), DataType::kUInt32, kNumGroups);
  p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups);

  HistogramCuts hmat;

@@ -350,6 +349,7 @@ void TestSketchFromWeights(bool with_group) {
  common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins, common::OmpGetNumThreads(0));

  MetaInfo info;
  Context ctx;
  auto& h_weights = info.weights_.HostVector();
  if (with_group) {
    h_weights.resize(kGroups);
@@ -363,7 +363,7 @@ void TestSketchFromWeights(bool with_group) {
    for (size_t i = 0; i < kGroups; ++i) {
      groups[i] = kRows / kGroups;
    }
    info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
    info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
  }

  info.num_row_ = kRows;
@@ -371,10 +371,10 @@ void TestSketchFromWeights(bool with_group) {

  // Assign weights.
  if (with_group) {
    m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
    m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
  }

  m->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
  m->SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
  m->Info().num_col_ = kCols;
  m->Info().num_row_ = kRows;
  ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);

@@ -520,7 +520,7 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) {
  for (size_t i = 0; i < kGroups; ++i) {
    groups[i] = kRows / kGroups;
  }
  m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
  m->SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
  HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0);

  h_weights.clear();
@@ -550,6 +550,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
      RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(
          &storage);
  MetaInfo info;
  Context ctx;
  auto& h_weights = info.weights_.HostVector();
  if (with_group) {
    h_weights.resize(kGroups);
@@ -563,7 +564,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
    for (size_t i = 0; i < kGroups; ++i) {
      groups[i] = kRows / kGroups;
    }
    info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
    info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
  }

  info.weights_.SetDevice(0);
@@ -582,10 +583,10 @@ void TestAdapterSketchFromWeights(bool with_group) {

  auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
  if (with_group) {
    dmat->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
    dmat->Info().SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups);
  }

  dmat->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size());
  dmat->Info().SetInfo(ctx, "weight", h_weights.data(), DataType::kFloat32, h_weights.size());
  dmat->Info().num_col_ = kCols;
  dmat->Info().num_row_ = kRows;
  ASSERT_EQ(cuts.Ptrs().size(), kCols + 1);

@@ -12,28 +12,29 @@
#include "xgboost/base.h"

TEST(MetaInfo, GetSet) {
  xgboost::Context ctx;
  xgboost::MetaInfo info;

  double double2[2] = {1.0, 2.0};

  EXPECT_EQ(info.labels.Size(), 0);
  info.SetInfo("label", double2, xgboost::DataType::kFloat32, 2);
  info.SetInfo(ctx, "label", double2, xgboost::DataType::kFloat32, 2);
  EXPECT_EQ(info.labels.Size(), 2);

  float float2[2] = {1.0f, 2.0f};
  EXPECT_EQ(info.GetWeight(1), 1.0f)
      << "When no weights are given, was expecting default value 1";
  info.SetInfo("weight", float2, xgboost::DataType::kFloat32, 2);
  info.SetInfo(ctx, "weight", float2, xgboost::DataType::kFloat32, 2);
  EXPECT_EQ(info.GetWeight(1), 2.0f);

  uint32_t uint32_t2[2] = {1U, 2U};
  EXPECT_EQ(info.base_margin_.Size(), 0);
  info.SetInfo("base_margin", uint32_t2, xgboost::DataType::kUInt32, 2);
  info.SetInfo(ctx, "base_margin", uint32_t2, xgboost::DataType::kUInt32, 2);
  EXPECT_EQ(info.base_margin_.Size(), 2);

  uint64_t uint64_t2[2] = {1U, 2U};
  EXPECT_EQ(info.group_ptr_.size(), 0);
  info.SetInfo("group", uint64_t2, xgboost::DataType::kUInt64, 2);
  info.SetInfo(ctx, "group", uint64_t2, xgboost::DataType::kUInt64, 2);
  ASSERT_EQ(info.group_ptr_.size(), 3);
  EXPECT_EQ(info.group_ptr_[2], 3);

@@ -73,6 +74,8 @@ TEST(MetaInfo, GetSetFeature) {

TEST(MetaInfo, SaveLoadBinary) {
  xgboost::MetaInfo info;
  xgboost::Context ctx;

  uint64_t constexpr kRows { 64 }, kCols { 32 };
  auto generator = []() {
    static float f = 0;
@@ -80,9 +83,9 @@ TEST(MetaInfo, SaveLoadBinary) {
  };
  std::vector<float> values (kRows);
  std::generate(values.begin(), values.end(), generator);
  info.SetInfo("label", values.data(), xgboost::DataType::kFloat32, kRows);
  info.SetInfo("weight", values.data(), xgboost::DataType::kFloat32, kRows);
  info.SetInfo("base_margin", values.data(), xgboost::DataType::kFloat32, kRows);
  info.SetInfo(ctx, "label", values.data(), xgboost::DataType::kFloat32, kRows);
  info.SetInfo(ctx, "weight", values.data(), xgboost::DataType::kFloat32, kRows);
  info.SetInfo(ctx, "base_margin", values.data(), xgboost::DataType::kFloat32, kRows);

  info.num_row_ = kRows;
  info.num_col_ = kCols;
@@ -210,13 +213,14 @@ TEST(MetaInfo, LoadQid) {

TEST(MetaInfo, CPUQid) {
  xgboost::MetaInfo info;
  xgboost::Context ctx;
  info.num_row_ = 100;
  std::vector<uint32_t> qid(info.num_row_, 0);
  for (size_t i = 0; i < qid.size(); ++i) {
    qid[i] = i;
  }

  info.SetInfo("qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_);
  info.SetInfo(ctx, "qid", qid.data(), xgboost::DataType::kUInt32, info.num_row_);
  ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
  ASSERT_EQ(info.group_ptr_.front(), 0);
  ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
@@ -232,12 +236,15 @@ TEST(MetaInfo, Validate) {
  info.num_nonzero_ = 12;
  info.num_col_ = 3;
  std::vector<xgboost::bst_group_t> groups (11);
  info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, 11);
  xgboost::Context ctx;
  info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, 11);
  EXPECT_THROW(info.Validate(0), dmlc::Error);

  std::vector<float> labels(info.num_row_ + 1);
  EXPECT_THROW(
      { info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); },
      {
        info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1);
      },
      dmlc::Error);

  // Make overflow data, which can happen when users pass group structure as int
@@ -247,14 +254,13 @@ TEST(MetaInfo, Validate) {
    groups.push_back(1562500);
  }
  groups.push_back(static_cast<xgboost::bst_group_t>(-1));
  EXPECT_THROW(info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32,
                            groups.size()),
  EXPECT_THROW(info.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size()),
               dmlc::Error);

#if defined(XGBOOST_USE_CUDA)
  info.group_ptr_.clear();
  labels.resize(info.num_row_);
  info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
  info.SetInfo(ctx, "label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
  info.labels.SetDevice(0);
  EXPECT_THROW(info.Validate(1), dmlc::Error);

@@ -263,12 +269,13 @@ TEST(MetaInfo, Validate) {
  d_groups.DevicePointer();  // pull to device
  std::string arr_interface_str{ArrayInterfaceStr(
      xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0))};
  EXPECT_THROW(info.SetInfo("group", xgboost::StringView{arr_interface_str}), dmlc::Error);
  EXPECT_THROW(info.SetInfo(ctx, "group", xgboost::StringView{arr_interface_str}), dmlc::Error);
#endif  // defined(XGBOOST_USE_CUDA)
}

TEST(MetaInfo, HostExtend) {
  xgboost::MetaInfo lhs, rhs;
  xgboost::Context ctx;
  size_t const kRows = 100;
  lhs.labels.Reshape(kRows);
  lhs.num_row_ = kRows;
@@ -282,8 +289,8 @@ TEST(MetaInfo, HostExtend) {
  for (size_t g = 0; g < kRows / per_group; ++g) {
    groups.emplace_back(per_group);
  }
  lhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
  rhs.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, groups.size());
  lhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size());
  rhs.SetInfo(ctx, "group", groups.data(), xgboost::DataType::kUInt32, groups.size());

  lhs.Extend(rhs, true, true);
  ASSERT_EQ(lhs.num_row_, kRows * 2);
@@ -300,5 +307,5 @@ TEST(MetaInfo, HostExtend) {
}

namespace xgboost {
TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(GenericParameter::kCpuId); }
TEST(MetaInfo, CPUStridedData) { TestMetaInfoStridedData(Context::kCpuId); }
}  // namespace xgboost

@@ -25,14 +25,13 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons

  std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
  column["shape"] = Array(j_shape);
  column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(sizeof(T))))});
  column["strides"] = Array(std::vector<Json>{Json(Integer{static_cast<Integer::Int>(sizeof(T))})});
  column["version"] = 3;
  column["typestr"] = String(typestr);

  auto p_d_data = d_data.data().get();
  std::vector<Json> j_data {
    Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
    Json(Boolean(false))};
  std::vector<Json> j_data{Json(Integer{reinterpret_cast<Integer::Int>(p_d_data)}),
                           Json(Boolean(false))};
  column["data"] = j_data;
  column["stream"] = nullptr;
  Json array(std::vector<Json>{column});
@@ -45,12 +44,13 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons

TEST(MetaInfo, FromInterface) {
  cudaSetDevice(0);
  Context ctx;
  thrust::device_vector<float> d_data;

  std::string str = PrepareData<float>("<f4", &d_data);

  MetaInfo info;
  info.SetInfo("label", str.c_str());
  info.SetInfo(ctx, "label", str.c_str());

  auto const& h_label = info.labels.HostView();
  ASSERT_EQ(h_label.Size(), d_data.size());
@@ -58,13 +58,13 @@ TEST(MetaInfo, FromInterface) {
    ASSERT_EQ(h_label(i), d_data[i]);
  }

  info.SetInfo("weight", str.c_str());
  info.SetInfo(ctx, "weight", str.c_str());
  auto const& h_weight = info.weights_.HostVector();
  for (size_t i = 0; i < d_data.size(); ++i) {
    ASSERT_EQ(h_weight[i], d_data[i]);
  }

  info.SetInfo("base_margin", str.c_str());
  info.SetInfo(ctx, "base_margin", str.c_str());
  auto const h_base_margin = info.base_margin_.View(GenericParameter::kCpuId);
  ASSERT_EQ(h_base_margin.Size(), d_data.size());
  for (size_t i = 0; i < d_data.size(); ++i) {
@@ -77,7 +77,7 @@ TEST(MetaInfo, FromInterface) {
  d_group_data[1] = 3;
  d_group_data[2] = 2;
  d_group_data[3] = 1;
  info.SetInfo("group", group_str.c_str());
  info.SetInfo(ctx, "group", group_str.c_str());
  std::vector<bst_group_t> expected_group_ptr = {0, 4, 7, 9, 10};
  EXPECT_EQ(info.group_ptr_, expected_group_ptr);
}
@@ -89,10 +89,11 @@ TEST(MetaInfo, GPUStridedData) {
TEST(MetaInfo, Group) {
  cudaSetDevice(0);
  MetaInfo info;
  Context ctx;

  thrust::device_vector<uint32_t> d_uint;
  std::string uint_str = PrepareData<uint32_t>("<u4", &d_uint);
  info.SetInfo("group", uint_str.c_str());
  info.SetInfo(ctx, "group", uint_str.c_str());
  auto& h_group = info.group_ptr_;
  ASSERT_EQ(h_group.size(), d_uint.size() + 1);
  for (size_t i = 1; i < h_group.size(); ++i) {
@@ -102,7 +103,7 @@ TEST(MetaInfo, Group) {
  thrust::device_vector<int64_t> d_int64;
  std::string int_str = PrepareData<int64_t>("<i8", &d_int64);
  info = MetaInfo();
  info.SetInfo("group", int_str.c_str());
  info.SetInfo(ctx, "group", int_str.c_str());
  h_group = info.group_ptr_;
  ASSERT_EQ(h_group.size(), d_uint.size() + 1);
  for (size_t i = 1; i < h_group.size(); ++i) {
@@ -113,11 +114,12 @@ TEST(MetaInfo, Group) {
  thrust::device_vector<float> d_float;
  std::string float_str = PrepareData<float>("<f4", &d_float);
  info = MetaInfo();
  EXPECT_ANY_THROW(info.SetInfo("group", float_str.c_str()));
  EXPECT_ANY_THROW(info.SetInfo(ctx, "group", float_str.c_str()));
}

TEST(MetaInfo, GPUQid) {
  xgboost::MetaInfo info;
  Context ctx;
  info.num_row_ = 100;
  thrust::device_vector<uint32_t> qid(info.num_row_, 0);
  for (size_t i = 0; i < qid.size(); ++i) {
@@ -127,7 +129,7 @@ TEST(MetaInfo, GPUQid) {
  Json array{std::vector<Json>{column}};
  std::string array_str;
  Json::Dump(array, &array_str);
  info.SetInfo("qid", array_str.c_str());
  info.SetInfo(ctx, "qid", array_str.c_str());
  ASSERT_EQ(info.group_ptr_.size(), info.num_row_ + 1);
  ASSERT_EQ(info.group_ptr_.front(), 0);
  ASSERT_EQ(info.group_ptr_.back(), info.num_row_);
@@ -142,11 +144,12 @@ TEST(MetaInfo, DeviceExtend) {
  dh::safe_cuda(cudaSetDevice(0));
  size_t const kRows = 100;
  MetaInfo lhs, rhs;
  Context ctx;

  thrust::device_vector<float> d_data;
  std::string str = PrepareData<float>("<f4", &d_data, kRows);
  lhs.SetInfo("label", str.c_str());
  rhs.SetInfo("label", str.c_str());
  lhs.SetInfo(ctx, "label", str.c_str());
  rhs.SetInfo(ctx, "label", str.c_str());
  ASSERT_FALSE(rhs.labels.Data()->HostCanRead());
  lhs.num_row_ = kRows;
  rhs.num_row_ = kRows;

@@ -16,6 +16,8 @@
namespace xgboost {
inline void TestMetaInfoStridedData(int32_t device) {
  MetaInfo info;
  Context ctx;
  ctx.UpdateAllowUnknown(Args{{"gpu_id", std::to_string(device)}});
  {
    // labels
    linalg::Tensor<float, 3> labels;
@@ -25,7 +27,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
    auto t_labels = labels.View(device).Slice(linalg::All(), 0, linalg::All());
    ASSERT_EQ(t_labels.Shape().size(), 2);

    info.SetInfo("label", StringView{ArrayInterfaceStr(t_labels)});
    info.SetInfo(ctx, "label", StringView{ArrayInterfaceStr(t_labels)});
    auto const& h_result = info.labels.View(-1);
    ASSERT_EQ(h_result.Shape().size(), 2);
    auto in_labels = labels.View(-1);
@@ -46,7 +48,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
    std::iota(h_qid.begin(), h_qid.end(), 0);
    auto s = qid.View(device).Slice(linalg::All(), 0);
    auto str = ArrayInterfaceStr(s);
    info.SetInfo("qid", StringView{str});
    info.SetInfo(ctx, "qid", StringView{str});
    auto const& h_result = info.group_ptr_;
    ASSERT_EQ(h_result.size(), s.Size() + 1);
  }
@@ -59,7 +61,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
    auto t_margin = base_margin.View(device).Slice(linalg::All(), 0, linalg::All());
    ASSERT_EQ(t_margin.Shape().size(), 2);

    info.SetInfo("base_margin", StringView{ArrayInterfaceStr(t_margin)});
    info.SetInfo(ctx, "base_margin", StringView{ArrayInterfaceStr(t_margin)});
    auto const& h_result = info.base_margin_.View(-1);
    ASSERT_EQ(h_result.Shape().size(), 2);
    auto in_margin = base_margin.View(-1);

@@ -257,7 +257,7 @@ TEST(Dart, Prediction) {
  for (size_t i = 0; i < kRows; ++i) {
    labels[i] = i % 2;
  }
  p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kRows);
  p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kRows);

  auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
  learner->SetParam("booster", "dart");

@@ -74,11 +74,9 @@ TEST(Learner, CheckGroup) {
    labels[i] = i % 2;
  }

  p_mat->Info().SetInfo(
      "weight", static_cast<void*>(weight.data()), DataType::kFloat32, kNumGroups);
  p_mat->Info().SetInfo(
      "group", group.data(), DataType::kUInt32, kNumGroups);
  p_mat->Info().SetInfo("label", labels.data(), DataType::kFloat32, kNumRows);
  p_mat->SetInfo("weight", static_cast<void *>(weight.data()), DataType::kFloat32, kNumGroups);
  p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups);
  p_mat->SetInfo("label", labels.data(), DataType::kFloat32, kNumRows);

  std::vector<std::shared_ptr<xgboost::DMatrix>> mat = {p_mat};
  auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
@@ -88,7 +86,7 @@ TEST(Learner, CheckGroup) {
  group.resize(kNumGroups+1);
  group[3] = 4;
  group[4] = 1;
  p_mat->Info().SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
  p_mat->SetInfo("group", group.data(), DataType::kUInt32, kNumGroups+1);
  EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat));
}

@@ -105,7 +103,7 @@ TEST(Learner, SLOW_CheckMultiBatch) {  // NOLINT
  for (size_t i = 0; i < num_row; ++i) {
    labels[i] = i % 2;
  }
  dmat->Info().SetInfo("label", labels.data(), DataType::kFloat32, num_row);
  dmat->SetInfo("label", labels.data(), DataType::kFloat32, num_row);
  std::vector<std::shared_ptr<DMatrix>> mat{dmat};
  auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
  learner->SetParams(Args{{"objective", "binary:logistic"}});