[dask] Add DaskXGBRanker (#6576)

* Initial support for distributed LTR using dask.

* Support `qid` in libxgboost.
* Refactor `predict`, `n_features_in_`, and `best_[score/iteration/ntree_limit]`
  to avoid duplicated code.
* Define `DaskXGBRanker`.

The dask ranker doesn't support the group structure; instead it accepts query
IDs and converts them to group pointers internally.
Jiaming Yuan, 2021-01-08 18:35:09 +08:00 (committed by GitHub)
parent 96d3d32265
commit 80065d571e
18 changed files with 755 additions and 351 deletions
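
As a quick illustration of the user-facing result, here is a minimal sketch of fitting the new ranker with `qid`; the cluster setup, synthetic data, and exact parameter values are assumptions for illustration, not part of this commit.

# Minimal sketch (assumed setup): fit the dask ranker with per-row query IDs.
import numpy as np
from dask import array as da
from dask.distributed import Client, LocalCluster

from xgboost import dask as dxgb

if __name__ == "__main__":
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        rows, cols = 1024, 16
        X = da.random.random((rows, cols), chunks=(256, cols))
        y = da.random.random(rows, chunks=256)
        # Rows sharing a query id form one group; `qid` must be sorted in
        # non-decreasing order along with the data.
        qid = da.from_array(np.sort(np.random.randint(0, 16, size=rows)),
                            chunks=256)

        ranker = dxgb.DaskXGBRanker(n_estimators=10)
        ranker.fit(X, y, qid=qid)
        scores = ranker.predict(X)  # lazy dask array of ranking scores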


@@ -498,6 +498,7 @@ XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
                                    xgboost::bst_ulong *out) {
   API_BEGIN();
   CHECK_HANDLE();
+  static_cast<Learner*>(handle)->Configure();
   *out = static_cast<Learner*>(handle)->GetNumFeature();
   API_END();
 }
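
For context on this hunk: calling `Configure()` ensures the learner is fully set up before the feature count is read, which is what allows `GetNumFeature` to become a `const` accessor (see the learner hunk at the end). On the Python side, `Booster.num_features()` is backed by this C API entry point; a small single-node sketch, with synthetic data as an assumption:

# Sketch: reading the feature count served by XGBoosterGetNumFeature.
import numpy as np
import xgboost as xgb

X = np.random.random((32, 8))
y = np.random.random(32)
booster = xgb.train({}, xgb.DMatrix(X, label=y), num_boost_round=1)
print(booster.num_features())  # 8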


@@ -374,13 +374,32 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
     DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
                        std::copy(cast_dptr, cast_dptr + num, base_margin.begin()));
   } else if (!std::strcmp(key, "group")) {
-    group_ptr_.resize(num + 1);
+    group_ptr_.clear(); group_ptr_.resize(num + 1, 0);
     DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
                        std::copy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1));
     group_ptr_[0] = 0;
     for (size_t i = 1; i < group_ptr_.size(); ++i) {
       group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
     }
+  } else if (!std::strcmp(key, "qid")) {
+    std::vector<uint32_t> query_ids(num, 0);
+    DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
+                       std::copy(cast_dptr, cast_dptr + num, query_ids.begin()));
+    bool non_dec = true;
+    for (size_t i = 1; i < query_ids.size(); ++i) {
+      if (query_ids[i] < query_ids[i - 1]) {
+        non_dec = false;
+        break;
+      }
+    }
+    CHECK(non_dec) << "`qid` must be sorted in non-decreasing order along with data.";
+    group_ptr_.clear(); group_ptr_.push_back(0);
+    for (size_t i = 1; i < query_ids.size(); ++i) {
+      if (query_ids[i] != query_ids[i - 1]) {
+        group_ptr_.push_back(i);
+      }
+    }
+    group_ptr_.push_back(query_ids.size());
   } else if (!std::strcmp(key, "label_lower_bound")) {
     auto& labels = labels_lower_bound_.HostVector();
     labels.resize(num);
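
The `qid` branch above converts a sorted query-id column into the cumulative `group_ptr_` layout by recording every index at which the id changes. The same logic restated in numpy, with illustrative values:

# numpy re-statement of the CPU qid -> group_ptr conversion above.
import numpy as np

qid = np.asarray([0, 0, 0, 1, 1, 3, 3, 3])
assert np.all(qid[1:] >= qid[:-1]), "`qid` must be sorted along with data"

# Boundaries: 0, each position where qid changes, then the total row count.
change = np.nonzero(qid[1:] != qid[:-1])[0] + 1
group_ptr = np.concatenate([[0], change, [qid.size]])
print(group_ptr)  # [0 3 5 8] -> query groups of sizes 3, 2, 3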


@@ -34,16 +34,20 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
   });
 }
 
+namespace {
+auto SetDeviceToPtr(void *ptr) {
+  cudaPointerAttributes attr;
+  dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
+  int32_t ptr_device = attr.device;
+  dh::safe_cuda(cudaSetDevice(ptr_device));
+  return ptr_device;
+}
+}  // anonymous namespace
+
 void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
   CHECK(column.type[1] == 'i' || column.type[1] == 'u')
       << "Expected integer metainfo";
-  auto SetDeviceToPtr = [](void* ptr) {
-    cudaPointerAttributes attr;
-    dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
-    int32_t ptr_device = attr.device;
-    dh::safe_cuda(cudaSetDevice(ptr_device));
-    return ptr_device;
-  };
   auto ptr_device = SetDeviceToPtr(column.data);
   dh::TemporaryArray<bst_group_t> temp(column.num_rows);
   auto d_tmp = temp.data();

@@ -95,6 +99,47 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
   } else if (key == "group") {
     CopyGroupInfoImpl(array_interface, &group_ptr_);
     return;
+  } else if (key == "qid") {
+    auto it = dh::MakeTransformIterator<uint32_t>(
+        thrust::make_counting_iterator(0ul),
+        [array_interface] __device__(size_t i) {
+          return static_cast<uint32_t>(array_interface.GetElement(i));
+        });
+    dh::caching_device_vector<bool> flag(1);
+    auto d_flag = dh::ToSpan(flag);
+    auto d = SetDeviceToPtr(array_interface.data);
+    dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
+    dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
+      if (static_cast<uint32_t>(array_interface.GetElement(i)) >
+          static_cast<uint32_t>(array_interface.GetElement(i + 1))) {
+        d_flag[0] = false;
+      }
+    });
+    bool non_dec = true;
+    dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool),
+                             cudaMemcpyDeviceToHost));
+    CHECK(non_dec)
+        << "`qid` must be sorted in non-decreasing order along with data.";
+    size_t bytes = 0;
+    dh::caching_device_vector<uint32_t> out(array_interface.num_rows);
+    dh::caching_device_vector<uint32_t> cnt(array_interface.num_rows);
+    HostDeviceVector<int> d_num_runs_out(1, 0, d);
+    cub::DeviceRunLengthEncode::Encode(nullptr, bytes, it, out.begin(),
+                                       cnt.begin(), d_num_runs_out.DevicePointer(),
+                                       array_interface.num_rows);
+    dh::caching_device_vector<char> tmp(bytes);
+    cub::DeviceRunLengthEncode::Encode(tmp.data().get(), bytes, it, out.begin(),
+                                       cnt.begin(), d_num_runs_out.DevicePointer(),
+                                       array_interface.num_rows);
+    auto h_num_runs_out = d_num_runs_out.HostSpan()[0];
+    group_ptr_.clear(); group_ptr_.resize(h_num_runs_out + 1, 0);
+    dh::XGBCachingDeviceAllocator<char> alloc;
+    thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
+                           cnt.begin() + h_num_runs_out, cnt.begin());
+    thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
+                 group_ptr_.begin() + 1);
+    return;
   } else if (key == "label_lower_bound") {
     CopyInfoImpl(array_interface, &labels_lower_bound_);
     return;
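
The device path reaches the same `group_ptr_` by a different route: after validating that `qid` is sorted, a CUB run-length encode collapses the column into one count per query, and an inclusive scan turns those counts into cumulative offsets. A numpy analogue of that pipeline, for illustration only:

# numpy analogue of the DeviceRunLengthEncode + inclusive_scan pipeline.
import numpy as np

qid = np.asarray([0, 0, 0, 1, 1, 3, 3, 3])      # sorted, as validated above
_, counts = np.unique(qid, return_counts=True)  # run-length encode
group_ptr = np.zeros(counts.size + 1, dtype=np.uint32)
group_ptr[1:] = np.cumsum(counts)               # inclusive scan into ptr[1:]
print(group_ptr)  # [0 3 5 8], identical to the CPU result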


@@ -436,7 +436,7 @@ class LearnerConfiguration : public Learner {
     }
   }
 
-  uint32_t GetNumFeature() override {
+  uint32_t GetNumFeature() const override {
     return learner_model_param_.num_feature;
   }