Support linalg data structures in check device. (#9243)
commit 0cba2cdbb0 (parent fc8110ef79)
src/data/data.cc
@@ -7,14 +7,15 @@
 #include <dmlc/registry.h>
 
 #include <array>
 #include <cstddef>
 #include <cstring>
 
 #include "../collective/communicator-inl.h"
 #include "../collective/communicator.h"
-#include "../common/common.h"
+#include "../common/algorithm.h"  // for StableSort
 #include "../common/api_entry.h"  // for XGBAPIThreadLocalEntry
-#include "../common/error_msg.h"  // for InfInData
+#include "../common/common.h"
+#include "../common/error_msg.h"  // for InfInData, GroupWeight, GroupSize
 #include "../common/group_data.h"
 #include "../common/io.h"
 #include "../common/linalg_op.h"
@@ -35,6 +36,7 @@
 #include "xgboost/context.h"
 #include "xgboost/host_device_vector.h"
 #include "xgboost/learner.h"
+#include "xgboost/linalg.h"  // Vector
 #include "xgboost/logging.h"
 #include "xgboost/string_view.h"
 #include "xgboost/version_config.h"
@@ -491,7 +493,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
   }
   // uint info
   if (key == "group") {
-    linalg::Tensor<bst_group_t, 1> t;
+    linalg::Vector<bst_group_t> t;
     CopyTensorInfoImpl(ctx, arr, &t);
     auto const& h_groups = t.Data()->HostVector();
     group_ptr_.clear();
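linalg::Vector<T> is the rank-1 alias of linalg::Tensor<T, 1> declared in xgboost/linalg.h (hence the "// Vector" comment in the include hunk above), so the change here is a readability cleanup rather than a behavioural one. A minimal sketch of the aliasing pattern, with stand-in types rather than the real classes:

// Stand-in types only; the real linalg::Tensor also carries shape, device
// placement and views, which are omitted here.
#include <cstdint>
#include <vector>

namespace linalg {
template <typename T, std::int32_t kDim>
class Tensor {
 public:
  std::vector<T>* Data() { return &data_; }  // mirrors Tensor::Data() returning backing storage

 private:
  std::vector<T> data_;
};

// Rank-1 convenience alias, matching the "// Vector" comment in the include hunk.
template <typename T>
using Vector = Tensor<T, 1>;
}  // namespace linalg

int main() {
  linalg::Vector<std::uint32_t> groups;  // identical type to linalg::Tensor<std::uint32_t, 1>
  groups.Data()->push_back(0u);
  return 0;
}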
@@ -516,6 +518,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
     data::ValidateQueryGroup(group_ptr_);
     return;
   }
+
   // float info
   linalg::Tensor<float, 1> t;
   CopyTensorInfoImpl<1>(ctx, arr, &t);
@@ -717,58 +720,63 @@ void MetaInfo::SynchronizeNumberOfColumns() {
   }
 }
 
+namespace {
+template <typename T>
+void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
+  CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
+      << "Data is resided on a different device than `gpu_id`. "
+      << "Device that data is on: " << v.DeviceIdx() << ", "
+      << "`gpu_id` for XGBoost: " << device;
+}
+template <typename T, std::int32_t D>
+void CheckDevice(std::int32_t device, linalg::Tensor<T, D> const& v) {
+  CheckDevice(device, *v.Data());
+}
+}  // anonymous namespace
+
 void MetaInfo::Validate(std::int32_t device) const {
   if (group_ptr_.size() != 0 && weights_.Size() != 0) {
-    CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
-        << "Size of weights must equal to number of groups when ranking "
-           "group is used.";
+    CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << error::GroupWeight();
     return;
   }
   if (group_ptr_.size() != 0) {
     CHECK_EQ(group_ptr_.back(), num_row_)
-        << "Invalid group structure. Number of rows obtained from groups "
-           "doesn't equal to actual number of rows given by data.";
+        << error::GroupSize() << "the actual number of rows given by data.";
   }
-  auto check_device = [device](HostDeviceVector<float> const& v) {
-    CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
-        << "Data is resided on a different device than `gpu_id`. "
-        << "Device that data is on: " << v.DeviceIdx() << ", "
-        << "`gpu_id` for XGBoost: " << device;
-  };
 
   if (weights_.Size() != 0) {
     CHECK_EQ(weights_.Size(), num_row_)
        << "Size of weights must equal to number of rows.";
-    check_device(weights_);
+    CheckDevice(device, weights_);
    return;
  }
  if (labels.Size() != 0) {
    CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows.";
-    check_device(*labels.Data());
+    CheckDevice(device, labels);
    return;
  }
  if (labels_lower_bound_.Size() != 0) {
    CHECK_EQ(labels_lower_bound_.Size(), num_row_)
        << "Size of label_lower_bound must equal to number of rows.";
-    check_device(labels_lower_bound_);
+    CheckDevice(device, labels_lower_bound_);
    return;
  }
  if (feature_weights.Size() != 0) {
    CHECK_EQ(feature_weights.Size(), num_col_)
        << "Size of feature_weights must equal to number of columns.";
-    check_device(feature_weights);
+    CheckDevice(device, feature_weights);
  }
  if (labels_upper_bound_.Size() != 0) {
    CHECK_EQ(labels_upper_bound_.Size(), num_row_)
        << "Size of label_upper_bound must equal to number of rows.";
-    check_device(labels_upper_bound_);
+    CheckDevice(device, labels_upper_bound_);
    return;
  }
  CHECK_LE(num_nonzero_, num_col_ * num_row_);
  if (base_margin_.Size() != 0) {
    CHECK_EQ(base_margin_.Size() % num_row_, 0)
        << "Size of base margin must be a multiple of number of rows.";
-    check_device(*base_margin_.Data());
+    CheckDevice(device, base_margin_);
  }
 }
 
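This hunk replaces the float-only check_device lambda with two CheckDevice overloads, so HostDeviceVector fields (weights_, the label bounds, feature_weights) and linalg::Tensor fields (labels, base_margin_) are validated through the same call. A self-contained sketch of the overload-dispatch idea, using simplified stand-in types rather than the real XGBoost classes:

// Stand-alone sketch of the dispatch pattern above; every type here is a
// simplified stand-in, not the real XGBoost implementation.
#include <cassert>
#include <cstdint>
#include <iostream>

constexpr std::int32_t kCpuId = -1;  // stands in for Context::kCpuId

template <typename T>
struct HostDeviceVector {  // stand-in: only tracks which device the data lives on
  std::int32_t device{kCpuId};
  std::int32_t DeviceIdx() const { return device; }
};

template <typename T, std::int32_t kDim>
struct Tensor {  // stand-in rank-kDim tensor backed by a device vector
  HostDeviceVector<T> storage;
  HostDeviceVector<T> const* Data() const { return &storage; }
};

// Generic check: data must sit on the CPU or on the same device as `device`.
// assert() stands in for the CHECK macro used in the diff.
template <typename T>
void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
  assert(v.DeviceIdx() == kCpuId || device == kCpuId || v.DeviceIdx() == device);
}

// Thin adapter: tensors of any rank forward to the vector overload.
template <typename T, std::int32_t kDim>
void CheckDevice(std::int32_t device, Tensor<T, kDim> const& t) {
  CheckDevice(device, *t.Data());
}

int main() {
  HostDeviceVector<float> weights;  // CPU-resident by default
  Tensor<float, 2> base_margin;     // also CPU-resident
  CheckDevice(0, weights);          // passes: CPU data is always accepted
  CheckDevice(0, base_margin);      // passes: dispatches through Data()
  std::cout << "device checks passed\n";
}

Keeping the tensor overload as a thin adapter means any new MetaInfo field, whether stored as a vector or a tensor, gets the same device consistency check without another bespoke lambda.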
@@ -1028,6 +1036,8 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
 bool SparsePage::IsIndicesSorted(int32_t n_threads) const {
   auto& h_offset = this->offset.HostVector();
   auto& h_data = this->data.HostVector();
+  n_threads = std::max(std::min(static_cast<std::size_t>(n_threads), this->Size()),
+                       static_cast<std::size_t>(1));
   std::vector<int32_t> is_sorted_tloc(n_threads, 0);
   common::ParallelFor(this->Size(), n_threads, [&](auto i) {
     auto beg = h_offset[i];
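The last hunk clamps n_threads to the number of rows before the per-row check; the lambda body continues beyond this excerpt. A hedged, self-contained sketch of the overall pattern, one sortedness flag per slot summed at the end, with plain loops standing in for common::ParallelFor and a simplified CSR layout as an assumption:

// Sketch of a per-row "indices sorted" check with one counter per thread slot,
// mirroring the shape of the code above; not the real SparsePage implementation.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

bool IsIndicesSorted(std::vector<std::size_t> const& offset,
                     std::vector<std::uint32_t> const& index, std::int32_t n_threads) {
  std::size_t n_rows = offset.size() - 1;
  // Clamp to [1, n_rows], as the added lines in the hunk do for this->Size().
  n_threads = std::max<std::int32_t>(
      std::min<std::int32_t>(n_threads, static_cast<std::int32_t>(n_rows)), 1);
  std::vector<std::int32_t> is_sorted_tloc(n_threads, 0);
  // Each thread would own a slice of rows; emulated sequentially here.
  for (std::size_t i = 0; i < n_rows; ++i) {
    std::int32_t tid = static_cast<std::int32_t>(i) % n_threads;
    auto beg = index.cbegin() + offset[i];
    auto end = index.cbegin() + offset[i + 1];
    is_sorted_tloc[tid] += static_cast<std::int32_t>(std::is_sorted(beg, end));
  }
  // All rows are sorted iff the per-slot counts sum to the number of rows.
  auto n_sorted = std::accumulate(is_sorted_tloc.cbegin(), is_sorted_tloc.cend(), 0);
  return static_cast<std::size_t>(n_sorted) == n_rows;
}

int main() {
  std::vector<std::size_t> offset{0, 3, 5};
  std::vector<std::uint32_t> index{0, 2, 4, 1, 3};  // both rows have sorted indices
  return IsIndicesSorted(offset, index, 2) ? 0 : 1;
}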