Small cleanups to various data types. (#8086)
- Use `bst_bin_t` in batch param constructor. - Use `StringView` to avoid `std::string` when appropriate. - Avoid using `MetaInfo` in quantile constructor to limit the scope of parameter.
This commit is contained in:
@@ -1137,16 +1137,15 @@ class SparsePageAdapterBatch {
|
||||
|
||||
public:
|
||||
struct Line {
|
||||
SparsePage::Inst inst;
|
||||
Entry const* inst;
|
||||
size_t n;
|
||||
bst_row_t ridx;
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return COOTuple{ridx, inst.data()[idx].index, inst.data()[idx].fvalue};
|
||||
}
|
||||
size_t Size() const { return inst.size(); }
|
||||
COOTuple GetElement(size_t idx) const { return {ridx, inst[idx].index, inst[idx].fvalue}; }
|
||||
size_t Size() const { return n; }
|
||||
};
|
||||
|
||||
explicit SparsePageAdapterBatch(HostSparsePageView page) : page_{std::move(page)} {}
|
||||
Line GetLine(size_t ridx) const { return Line{page_[ridx], ridx}; }
|
||||
Line GetLine(size_t ridx) const { return Line{page_[ridx].data(), page_[ridx].size(), ridx}; }
|
||||
size_t Size() const { return page_.Size(); }
|
||||
};
|
||||
}; // namespace data
|
||||
|
||||
@@ -92,9 +92,8 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
*/
|
||||
class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
||||
public:
|
||||
explicit CudfAdapter(std::string cuda_interfaces_str) {
|
||||
Json interfaces =
|
||||
Json::Load({cuda_interfaces_str.c_str(), cuda_interfaces_str.size()});
|
||||
explicit CudfAdapter(StringView cuda_interfaces_str) {
|
||||
Json interfaces = Json::Load(cuda_interfaces_str);
|
||||
std::vector<Json> const& json_columns = get<Array>(interfaces);
|
||||
size_t n_columns = json_columns.size();
|
||||
CHECK_GT(n_columns, 0) << "Number of columns must not equal to 0.";
|
||||
@@ -123,6 +122,9 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
||||
columns_ = columns;
|
||||
batch_ = CudfAdapterBatch(dh::ToSpan(columns_), num_rows_);
|
||||
}
|
||||
explicit CudfAdapter(std::string cuda_interfaces_str)
|
||||
: CudfAdapter{StringView{cuda_interfaces_str}} {}
|
||||
|
||||
const CudfAdapterBatch& Value() const override {
|
||||
CHECK_EQ(batch_.columns_.data(), columns_.data().get());
|
||||
return batch_;
|
||||
@@ -163,9 +165,8 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
|
||||
|
||||
class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
||||
public:
|
||||
explicit CupyAdapter(std::string cuda_interface_str) {
|
||||
Json json_array_interface =
|
||||
Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
|
||||
explicit CupyAdapter(StringView cuda_interface_str) {
|
||||
Json json_array_interface = Json::Load(cuda_interface_str);
|
||||
array_interface_ = ArrayInterface<2>(get<Object const>(json_array_interface));
|
||||
batch_ = CupyAdapterBatch(array_interface_);
|
||||
if (array_interface_.Shape(0) == 0) {
|
||||
@@ -174,6 +175,8 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
||||
device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
|
||||
CHECK_NE(device_idx_, -1);
|
||||
}
|
||||
explicit CupyAdapter(std::string cuda_interface_str)
|
||||
: CupyAdapter{StringView{cuda_interface_str}} {}
|
||||
const CupyAdapterBatch& Value() const override { return batch_; }
|
||||
|
||||
size_t NumRows() const { return array_interface_.Shape(0); }
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
|
||||
std::shared_ptr<data::CudfAdapter> adapter {new data::CudfAdapter(interface_str)};
|
||||
void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
|
||||
std::shared_ptr<data::CudfAdapter> adapter{new CudfAdapter{interface_str}};
|
||||
auto const& value = adapter->Value();
|
||||
this->batch_ = adapter;
|
||||
ctx_.gpu_id = adapter->DeviceIdx();
|
||||
@@ -19,8 +19,8 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
|
||||
}
|
||||
}
|
||||
|
||||
void DMatrixProxy::FromCudaArray(std::string interface_str) {
|
||||
std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
|
||||
void DMatrixProxy::FromCudaArray(StringView interface_str) {
|
||||
std::shared_ptr<CupyAdapter> adapter(new CupyAdapter{StringView{interface_str}});
|
||||
this->batch_ = adapter;
|
||||
ctx_.gpu_id = adapter->DeviceIdx();
|
||||
this->Info().num_col_ = adapter->NumColumns();
|
||||
|
||||
@@ -48,8 +48,8 @@ class DMatrixProxy : public DMatrix {
|
||||
Context ctx_;
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
void FromCudaColumnar(std::string interface_str);
|
||||
void FromCudaArray(std::string interface_str);
|
||||
void FromCudaColumnar(StringView interface_str);
|
||||
void FromCudaArray(StringView interface_str);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
public:
|
||||
@@ -58,9 +58,8 @@ class DMatrixProxy : public DMatrix {
|
||||
void SetCUDAArray(char const* c_interface) {
|
||||
common::AssertGPUSupport();
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
std::string interface_str = c_interface;
|
||||
Json json_array_interface =
|
||||
Json::Load({interface_str.c_str(), interface_str.size()});
|
||||
StringView interface_str{c_interface};
|
||||
Json json_array_interface = Json::Load(interface_str);
|
||||
if (IsA<Array>(json_array_interface)) {
|
||||
this->FromCudaColumnar(interface_str);
|
||||
} else {
|
||||
@@ -114,10 +113,11 @@ class DMatrixProxy : public DMatrix {
|
||||
}
|
||||
};
|
||||
|
||||
inline DMatrixProxy *MakeProxy(DMatrixHandle proxy) {
|
||||
auto proxy_handle = static_cast<std::shared_ptr<DMatrix> *>(proxy);
|
||||
inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {
|
||||
auto proxy_handle = static_cast<std::shared_ptr<DMatrix>*>(proxy);
|
||||
CHECK(proxy_handle) << "Invalid proxy handle.";
|
||||
DMatrixProxy *typed = static_cast<DMatrixProxy *>(proxy_handle->get());
|
||||
DMatrixProxy* typed = static_cast<DMatrixProxy*>(proxy_handle->get());
|
||||
CHECK(typed) << "Invalid proxy handle.";
|
||||
return typed;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user