Cleanup set info. (#10139)

- Use the array interface internally.
- Deprecate `XGDMatrixSetDenseInfo`.
- Deprecate `XGDMatrixSetUIntInfo`.
- Move the handling of `DataType` into the deprecated C function.

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2024-03-26 23:26:24 +08:00
committed by GitHub
parent 6a7c6a8ae6
commit 230010d9a0
37 changed files with 246 additions and 268 deletions

View File

@@ -11,7 +11,6 @@
#include <cmath> // for abs
#include <cstdint> // for uint64_t, int32_t, uint8_t, uint32_t
#include <cstring> // for size_t, strcmp, memcpy
#include <exception> // for exception
#include <iostream> // for operator<<, basic_ostream, basic_ostream::op...
#include <map> // for map, operator!=
#include <numeric> // for accumulate, partial_sum
@@ -22,7 +21,6 @@
#include "../collective/communicator.h" // for Operation
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/common.h" // for Split
#include "../common/error_msg.h" // for GroupSize, GroupWeight, InfInData
#include "../common/group_data.h" // for ParallelGroupBuilder
#include "../common/io.h" // for PeekableInStream
@@ -473,11 +471,11 @@ void MetaInfo::SetInfo(Context const& ctx, StringView key, StringView interface_
<< ", must have at least 1 column even if it's empty.";
auto const& first = get<Object const>(array.front());
auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr);
is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
} else {
auto const& first = get<Object const>(j_interface);
auto ptr = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(first);
is_cuda = ArrayInterfaceHandler::IsCudaPtr(ptr);
is_cuda = first.find("stream") != first.cend() || ArrayInterfaceHandler::IsCudaPtr(ptr);
}
if (is_cuda) {
@@ -567,46 +565,6 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
}
}
void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
size_t num) {
CHECK(key);
auto proc = [&](auto cast_d_ptr) {
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, DeviceOrd::CPU());
CHECK(t.CContiguous());
Json interface {
linalg::ArrayInterface(t)
};
assert(ArrayInterface<1>{interface}.is_contiguous);
return interface;
};
// Legacy code using XGBoost dtype, which is a small subset of array interface types.
switch (dtype) {
case xgboost::DataType::kFloat32: {
auto cast_ptr = reinterpret_cast<const float*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kDouble: {
auto cast_ptr = reinterpret_cast<const double*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt32: {
auto cast_ptr = reinterpret_cast<const uint32_t*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
case xgboost::DataType::kUInt64: {
auto cast_ptr = reinterpret_cast<const uint64_t*>(dptr);
this->SetInfoFromHost(ctx, key, proc(cast_ptr));
break;
}
default:
LOG(FATAL) << "Unknown data type" << static_cast<uint8_t>(dtype);
}
}
void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
const void** out_dptr) const {
if (dtype == DataType::kFloat32) {

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2021-2023, XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*/
#include "file_iterator.h"
@@ -10,7 +10,10 @@
#include <ostream> // for operator<<, basic_ostream, istringstream
#include <vector> // for vector
#include "../common/common.h" // for Split
#include "../common/common.h" // for Split
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
#include "xgboost/linalg.h"
#include "xgboost/logging.h" // for CHECK
#include "xgboost/string_view.h" // for operator<<, StringView
namespace xgboost::data {
@@ -28,10 +31,10 @@ std::string ValidateFileFormat(std::string const& uri) {
for (size_t i = 0; i < arg_list.size(); ++i) {
std::istringstream is(arg_list[i]);
std::pair<std::string, std::string> kv;
CHECK(std::getline(is, kv.first, '=')) << "Invalid uri argument format"
<< " for key in arg " << i + 1;
CHECK(std::getline(is, kv.second)) << "Invalid uri argument format"
<< " for value in arg " << i + 1;
CHECK(std::getline(is, kv.first, '='))
<< "Invalid uri argument format" << " for key in arg " << i + 1;
CHECK(std::getline(is, kv.second))
<< "Invalid uri argument format" << " for value in arg " << i + 1;
args.insert(kv);
}
if (args.find("format") == args.cend()) {
@@ -48,4 +51,41 @@ std::string ValidateFileFormat(std::string const& uri) {
return name_args[0] + "?" + name_args[1] + '#' + name_args_cache[1];
}
}
int FileIterator::Next() {
CHECK(parser_);
if (parser_->Next()) {
row_block_ = parser_->Value();
indptr_ = linalg::Make1dInterface(row_block_.offset, row_block_.size + 1);
values_ = linalg::Make1dInterface(row_block_.value, row_block_.offset[row_block_.size]);
indices_ = linalg::Make1dInterface(row_block_.index, row_block_.offset[row_block_.size]);
size_t n_columns =
*std::max_element(row_block_.index, row_block_.index + row_block_.offset[row_block_.size]);
// dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
// this condition and just add 1 to n_columns
n_columns += 1;
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(), values_.c_str(), n_columns);
if (row_block_.label) {
auto str = linalg::Make1dInterface(row_block_.label, row_block_.size);
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
}
if (row_block_.qid) {
auto str = linalg::Make1dInterface(row_block_.qid, row_block_.size);
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
}
if (row_block_.weight) {
auto str = linalg::Make1dInterface(row_block_.weight, row_block_.size);
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
}
// Continue iteration
return true;
} else {
// Stop iteration
return false;
}
}
} // namespace xgboost::data

View File

@@ -1,20 +1,16 @@
/**
* Copyright 2021-2023, XGBoost contributors
* Copyright 2021-2024, XGBoost contributors
*/
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
#define XGBOOST_DATA_FILE_ITERATOR_H_
#include <algorithm> // for max_element
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <memory> // for unique_ptr
#include <string> // for string
#include <utility> // for move
#include "dmlc/data.h" // for RowBlock, Parser
#include "xgboost/c_api.h" // for XGDMatrixSetDenseInfo, XGDMatrixFree, XGProxyDMatrixCreate
#include "xgboost/linalg.h" // for ArrayInterfaceStr, MakeVec
#include "xgboost/logging.h" // for CHECK
#include "xgboost/c_api.h" // for XGDMatrixFree, XGProxyDMatrixCreate
namespace xgboost::data {
[[nodiscard]] std::string ValidateFileFormat(std::string const& uri);
@@ -53,41 +49,7 @@ class FileIterator {
XGDMatrixFree(proxy_);
}
int Next() {
CHECK(parser_);
if (parser_->Next()) {
row_block_ = parser_->Value();
using linalg::MakeVec;
indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1));
values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size]));
indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size]));
size_t n_columns = *std::max_element(row_block_.index,
row_block_.index + row_block_.offset[row_block_.size]);
// dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
// this condition and just add 1 to n_columns
n_columns += 1;
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(),
values_.c_str(), n_columns);
if (row_block_.label) {
XGDMatrixSetDenseInfo(proxy_, "label", row_block_.label, row_block_.size, 1);
}
if (row_block_.qid) {
XGDMatrixSetDenseInfo(proxy_, "qid", row_block_.qid, row_block_.size, 1);
}
if (row_block_.weight) {
XGDMatrixSetDenseInfo(proxy_, "weight", row_block_.weight, row_block_.size, 1);
}
// Continue iteration
return true;
} else {
// Stop iteration
return false;
}
}
int Next();
auto Proxy() -> decltype(proxy_) { return proxy_; }