Support linalg data structures in check device. (#9243)
This commit is contained in:
parent
fc8110ef79
commit
0cba2cdbb0
@ -7,14 +7,15 @@
|
|||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
#include <cstddef>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
#include "../collective/communicator.h"
|
#include "../collective/communicator.h"
|
||||||
#include "../common/common.h"
|
|
||||||
#include "../common/algorithm.h" // for StableSort
|
#include "../common/algorithm.h" // for StableSort
|
||||||
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
|
||||||
#include "../common/error_msg.h" // for InfInData
|
#include "../common/common.h"
|
||||||
|
#include "../common/error_msg.h" // for InfInData, GroupWeight, GroupSize
|
||||||
#include "../common/group_data.h"
|
#include "../common/group_data.h"
|
||||||
#include "../common/io.h"
|
#include "../common/io.h"
|
||||||
#include "../common/linalg_op.h"
|
#include "../common/linalg_op.h"
|
||||||
@ -35,6 +36,7 @@
|
|||||||
#include "xgboost/context.h"
|
#include "xgboost/context.h"
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/host_device_vector.h"
|
||||||
#include "xgboost/learner.h"
|
#include "xgboost/learner.h"
|
||||||
|
#include "xgboost/linalg.h" // Vector
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/string_view.h"
|
#include "xgboost/string_view.h"
|
||||||
#include "xgboost/version_config.h"
|
#include "xgboost/version_config.h"
|
||||||
@ -491,7 +493,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
|
|||||||
}
|
}
|
||||||
// uint info
|
// uint info
|
||||||
if (key == "group") {
|
if (key == "group") {
|
||||||
linalg::Tensor<bst_group_t, 1> t;
|
linalg::Vector<bst_group_t> t;
|
||||||
CopyTensorInfoImpl(ctx, arr, &t);
|
CopyTensorInfoImpl(ctx, arr, &t);
|
||||||
auto const& h_groups = t.Data()->HostVector();
|
auto const& h_groups = t.Data()->HostVector();
|
||||||
group_ptr_.clear();
|
group_ptr_.clear();
|
||||||
@ -516,6 +518,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
|
|||||||
data::ValidateQueryGroup(group_ptr_);
|
data::ValidateQueryGroup(group_ptr_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// float info
|
// float info
|
||||||
linalg::Tensor<float, 1> t;
|
linalg::Tensor<float, 1> t;
|
||||||
CopyTensorInfoImpl<1>(ctx, arr, &t);
|
CopyTensorInfoImpl<1>(ctx, arr, &t);
|
||||||
@ -717,58 +720,63 @@ void MetaInfo::SynchronizeNumberOfColumns() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
|
||||||
|
CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
|
||||||
|
<< "Data is resided on a different device than `gpu_id`. "
|
||||||
|
<< "Device that data is on: " << v.DeviceIdx() << ", "
|
||||||
|
<< "`gpu_id` for XGBoost: " << device;
|
||||||
|
}
|
||||||
|
template <typename T, std::int32_t D>
|
||||||
|
void CheckDevice(std::int32_t device, linalg::Tensor<T, D> const& v) {
|
||||||
|
CheckDevice(device, *v.Data());
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
void MetaInfo::Validate(std::int32_t device) const {
|
void MetaInfo::Validate(std::int32_t device) const {
|
||||||
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
|
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
|
||||||
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
|
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << error::GroupWeight();
|
||||||
<< "Size of weights must equal to number of groups when ranking "
|
|
||||||
"group is used.";
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (group_ptr_.size() != 0) {
|
if (group_ptr_.size() != 0) {
|
||||||
CHECK_EQ(group_ptr_.back(), num_row_)
|
CHECK_EQ(group_ptr_.back(), num_row_)
|
||||||
<< "Invalid group structure. Number of rows obtained from groups "
|
<< error::GroupSize() << "the actual number of rows given by data.";
|
||||||
"doesn't equal to actual number of rows given by data.";
|
|
||||||
}
|
}
|
||||||
auto check_device = [device](HostDeviceVector<float> const& v) {
|
|
||||||
CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
|
|
||||||
<< "Data is resided on a different device than `gpu_id`. "
|
|
||||||
<< "Device that data is on: " << v.DeviceIdx() << ", "
|
|
||||||
<< "`gpu_id` for XGBoost: " << device;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (weights_.Size() != 0) {
|
if (weights_.Size() != 0) {
|
||||||
CHECK_EQ(weights_.Size(), num_row_)
|
CHECK_EQ(weights_.Size(), num_row_)
|
||||||
<< "Size of weights must equal to number of rows.";
|
<< "Size of weights must equal to number of rows.";
|
||||||
check_device(weights_);
|
CheckDevice(device, weights_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (labels.Size() != 0) {
|
if (labels.Size() != 0) {
|
||||||
CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows.";
|
CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows.";
|
||||||
check_device(*labels.Data());
|
CheckDevice(device, labels);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (labels_lower_bound_.Size() != 0) {
|
if (labels_lower_bound_.Size() != 0) {
|
||||||
CHECK_EQ(labels_lower_bound_.Size(), num_row_)
|
CHECK_EQ(labels_lower_bound_.Size(), num_row_)
|
||||||
<< "Size of label_lower_bound must equal to number of rows.";
|
<< "Size of label_lower_bound must equal to number of rows.";
|
||||||
check_device(labels_lower_bound_);
|
CheckDevice(device, labels_lower_bound_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (feature_weights.Size() != 0) {
|
if (feature_weights.Size() != 0) {
|
||||||
CHECK_EQ(feature_weights.Size(), num_col_)
|
CHECK_EQ(feature_weights.Size(), num_col_)
|
||||||
<< "Size of feature_weights must equal to number of columns.";
|
<< "Size of feature_weights must equal to number of columns.";
|
||||||
check_device(feature_weights);
|
CheckDevice(device, feature_weights);
|
||||||
}
|
}
|
||||||
if (labels_upper_bound_.Size() != 0) {
|
if (labels_upper_bound_.Size() != 0) {
|
||||||
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
|
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
|
||||||
<< "Size of label_upper_bound must equal to number of rows.";
|
<< "Size of label_upper_bound must equal to number of rows.";
|
||||||
check_device(labels_upper_bound_);
|
CheckDevice(device, labels_upper_bound_);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
CHECK_LE(num_nonzero_, num_col_ * num_row_);
|
CHECK_LE(num_nonzero_, num_col_ * num_row_);
|
||||||
if (base_margin_.Size() != 0) {
|
if (base_margin_.Size() != 0) {
|
||||||
CHECK_EQ(base_margin_.Size() % num_row_, 0)
|
CHECK_EQ(base_margin_.Size() % num_row_, 0)
|
||||||
<< "Size of base margin must be a multiple of number of rows.";
|
<< "Size of base margin must be a multiple of number of rows.";
|
||||||
check_device(*base_margin_.Data());
|
CheckDevice(device, base_margin_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1028,6 +1036,8 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
|
|||||||
bool SparsePage::IsIndicesSorted(int32_t n_threads) const {
|
bool SparsePage::IsIndicesSorted(int32_t n_threads) const {
|
||||||
auto& h_offset = this->offset.HostVector();
|
auto& h_offset = this->offset.HostVector();
|
||||||
auto& h_data = this->data.HostVector();
|
auto& h_data = this->data.HostVector();
|
||||||
|
n_threads = std::max(std::min(static_cast<std::size_t>(n_threads), this->Size()),
|
||||||
|
static_cast<std::size_t>(1));
|
||||||
std::vector<int32_t> is_sorted_tloc(n_threads, 0);
|
std::vector<int32_t> is_sorted_tloc(n_threads, 0);
|
||||||
common::ParallelFor(this->Size(), n_threads, [&](auto i) {
|
common::ParallelFor(this->Size(), n_threads, [&](auto i) {
|
||||||
auto beg = h_offset[i];
|
auto beg = h_offset[i];
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user