enable ROCm on latest XGBoost
src/data/adapter.cc (new file, 28 additions)
@@ -0,0 +1,28 @@
/**
 * Copyright 2019-2023, XGBoost Contributors
 */
#include "adapter.h"

#include "../c_api/c_api_error.h"  // for API_BEGIN, API_END
#include "xgboost/c_api.h"

namespace xgboost::data {
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
bool IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>::Next() {
  if ((*next_callback_)(
          data_handle_,
          [](void *handle, XGBoostBatchCSR batch) -> int {
            API_BEGIN();
            static_cast<IteratorAdapter *>(handle)->SetData(batch);
            API_END();
          },
          this) != 0) {
    at_first_ = false;
    return true;
  } else {
    return false;
  }
}

template class IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>;
}  // namespace xgboost::data
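The Next() implementation above was moved out of the header (see the adapter.h hunk below) and depends on a C-style callback contract: the producer's "next" function returns nonzero while batches remain, handing each batch to the adapter through a setter. A minimal stand-in sketch of that contract; FakeBatchCSR, FakeNextBatch, and StoreBatch are hypothetical names, the real types being XGBoostBatchCSR and XGBCallbackDataIterNext from the C API:

#include <cstddef>

// Stand-in for XGBoostBatchCSR: one CSR chunk handed over per callback.
struct FakeBatchCSR {
  std::size_t size{0};
};

using SetDataFn = int (*)(void *handle, FakeBatchCSR batch);

// Plays the role of XGBCallbackDataIterNext: delivers two batches, then
// signals end-of-stream by returning 0, exactly the shape Next() tests for.
int FakeNextBatch(void * /*data_handle*/, SetDataFn set_data, void *adapter_handle) {
  static int remaining = 2;
  if (remaining == 0) return 0;  // no more batches
  --remaining;
  set_data(adapter_handle, FakeBatchCSR{64});  // adapter stores it via SetData()
  return 1;  // a batch was delivered
}

int StoreBatch(void *, FakeBatchCSR) { return 0; }  // stands in for SetData()

int main() {
  int n = 0;
  while (FakeNextBatch(nullptr, &StoreBatch, nullptr) != 0) { ++n; }
  return n == 2 ? 0 : 1;  // two batches consumed, then end-of-stream
}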
src/data/adapter.h
@@ -1,5 +1,5 @@
/*!
 * Copyright (c) 2019~2021 by Contributors
/**
 * Copyright 2019-2023, XGBoost Contributors
 * \file adapter.h
 */
#ifndef XGBOOST_DATA_ADAPTER_H_
@@ -16,11 +16,9 @@
#include <utility>  // std::move
#include <vector>

#include "../c_api/c_api_error.h"
#include "../common/error_msg.h"  // for MaxFeatureSize
#include "../common/math.h"
#include "array_interface.h"
#include "arrow-cdi.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/logging.h"
@@ -743,8 +741,10 @@ class FileAdapter : dmlc::DataIter<FileAdapterBatch> {
  dmlc::Parser<uint32_t>* parser_;
};

/*! \brief Data iterator that takes callback to return data, used in JVM package for
 * accepting data iterator. */
/**
 * @brief Data iterator that takes callback to return data, used in JVM package for accepting data
 * iterator.
 */
template <typename DataIterHandle, typename XGBCallbackDataIterNext, typename XGBoostBatchCSR>
class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
 public:
@@ -758,23 +758,9 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
    CHECK(at_first_) << "Cannot reset IteratorAdapter";
  }

  bool Next() override {
    if ((*next_callback_)(
            data_handle_,
            [](void *handle, XGBoostBatchCSR batch) -> int {
              API_BEGIN();
              static_cast<IteratorAdapter *>(handle)->SetData(batch);
              API_END();
            },
            this) != 0) {
      at_first_ = false;
      return true;
    } else {
      return false;
    }
  }
  [[nodiscard]] bool Next() override;

  FileAdapterBatch const& Value() const override {
  [[nodiscard]] FileAdapterBatch const& Value() const override {
    return *batch_.get();
  }

@@ -822,12 +808,12 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
    block_.index = dmlc::BeginPtr(index_);
    block_.value = dmlc::BeginPtr(value_);

    batch_.reset(new FileAdapterBatch(&block_, row_offset_));
    batch_ = std::make_unique<FileAdapterBatch>(&block_, row_offset_);
    row_offset_ += offset_.size() - 1;
  }

  size_t NumColumns() const { return columns_; }
  size_t NumRows() const { return kAdapterUnknownSize; }
  [[nodiscard]] std::size_t NumColumns() const { return columns_; }
  [[nodiscard]] std::size_t NumRows() const { return kAdapterUnknownSize; }

 private:
  std::vector<size_t> offset_;
@@ -849,356 +835,6 @@ class IteratorAdapter : public dmlc::DataIter<FileAdapterBatch> {
  std::unique_ptr<FileAdapterBatch> batch_;
};

enum ColumnDType : uint8_t {
  kUnknown,
  kInt8,
  kUInt8,
  kInt16,
  kUInt16,
  kInt32,
  kUInt32,
  kInt64,
  kUInt64,
  kFloat,
  kDouble
};

class Column {
 public:
  Column() = default;

  Column(size_t col_idx, size_t length, size_t null_count, const uint8_t* bitmap)
      : col_idx_{col_idx}, length_{length}, null_count_{null_count}, bitmap_{bitmap} {}

  virtual ~Column() = default;

  Column(const Column&) = delete;
  Column& operator=(const Column&) = delete;
  Column(Column&&) = delete;
  Column& operator=(Column&&) = delete;

  // whether the valid bit is set for this element
  bool IsValid(size_t row_idx) const {
    return (!bitmap_ || (bitmap_[row_idx / 8] & (1 << (row_idx % 8))));
  }

  virtual COOTuple GetElement(size_t row_idx) const = 0;

  virtual bool IsValidElement(size_t row_idx) const = 0;

  virtual std::vector<float> AsFloatVector() const = 0;

  virtual std::vector<uint64_t> AsUint64Vector() const = 0;

  size_t Length() const { return length_; }

 protected:
  size_t col_idx_;
  size_t length_;
  size_t null_count_;
  const uint8_t* bitmap_;
};

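IsValid() above implements the Arrow validity-bitmap convention: one bit per row, least-significant bit first within each byte, and a null bitmap pointer meaning every row is valid. A self-contained illustration with toy data (BitmapIsValid is a hypothetical helper mirroring the member function):

#include <cstddef>
#include <cstdint>
#include <iostream>

bool BitmapIsValid(const std::uint8_t* bitmap, std::size_t row_idx) {
  // No bitmap means the column has no nulls at all.
  return !bitmap || (bitmap[row_idx / 8] & (1 << (row_idx % 8)));
}

int main() {
  std::uint8_t bitmap[] = {0b00000101};  // rows 0 and 2 valid, row 1 null
  for (std::size_t i = 0; i < 3; ++i) {
    std::cout << "row " << i << (BitmapIsValid(bitmap, i) ? ": valid" : ": null") << "\n";
  }
}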
// Only columns of primitive types are supported. An ArrowColumnarBatch is a
// collection of std::shared_ptr<PrimitiveColumn>. These columns can be of different data types.
// Hence, PrimitiveColumn is a class template; and all concrete PrimitiveColumns
// derive from the abstract class Column.
template <typename T>
class PrimitiveColumn : public Column {
  static constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();

 public:
  PrimitiveColumn(size_t idx, size_t length, size_t null_count,
                  const uint8_t* bitmap, const T* data, float missing)
      : Column{idx, length, null_count, bitmap}, data_{data}, missing_{missing} {}

  COOTuple GetElement(size_t row_idx) const override {
    CHECK(data_ && row_idx < length_) << "Column is empty or out-of-bound index of the column";
    return { row_idx, col_idx_, IsValidElement(row_idx) ?
             static_cast<float>(data_[row_idx]) : kNaN };
  }

  bool IsValidElement(size_t row_idx) const override {
    // std::isfinite needs the cast to double to avoid an MSVC error
    return IsValid(row_idx)
        && std::isfinite(static_cast<double>(data_[row_idx]))
        && static_cast<float>(data_[row_idx]) != missing_;
  }

  std::vector<float> AsFloatVector() const override {
    CHECK(data_) << "Column is empty";
    std::vector<float> fv(length_);
    std::transform(data_, data_ + length_, fv.begin(),
                   [](T v) { return static_cast<float>(v); });
    return fv;
  }

  std::vector<uint64_t> AsUint64Vector() const override {
    CHECK(data_) << "Column is empty";
    std::vector<uint64_t> iv(length_);
    std::transform(data_, data_ + length_, iv.begin(),
                   [](T v) { return static_cast<uint64_t>(v); });
    return iv;
  }

 private:
  const T* data_;
  float missing_;  // user specified missing value
};

struct ColumnarMetaInfo {
  // data type of the column
  ColumnDType type{ColumnDType::kUnknown};
  // location of the column in an Arrow record batch
  int64_t loc{-1};
};

struct ArrowSchemaImporter {
  std::vector<ColumnarMetaInfo> columns;

  // map Arrow format strings to types
  static ColumnDType FormatMap(char const* format_str) {
    CHECK(format_str) << "Format string cannot be empty";
    switch (format_str[0]) {
      case 'c':
        return ColumnDType::kInt8;
      case 'C':
        return ColumnDType::kUInt8;
      case 's':
        return ColumnDType::kInt16;
      case 'S':
        return ColumnDType::kUInt16;
      case 'i':
        return ColumnDType::kInt32;
      case 'I':
        return ColumnDType::kUInt32;
      case 'l':
        return ColumnDType::kInt64;
      case 'L':
        return ColumnDType::kUInt64;
      case 'f':
        return ColumnDType::kFloat;
      case 'g':
        return ColumnDType::kDouble;
      default:
        CHECK(false) << "Column data type not supported by XGBoost";
        return ColumnDType::kUnknown;
    }
  }

  void Import(struct ArrowSchema *schema) {
    if (schema) {
      CHECK(std::string(schema->format) == "+s");  // NOLINT
      CHECK(columns.empty());
      for (auto i = 0; i < schema->n_children; ++i) {
        std::string name{schema->children[i]->name};
        ColumnDType type = FormatMap(schema->children[i]->format);
        ColumnarMetaInfo col_info{type, i};
        columns.push_back(col_info);
      }
      if (schema->release) {
        schema->release(schema);
      }
    }
  }
};

class ArrowColumnarBatch {
 public:
  ArrowColumnarBatch(struct ArrowArray *rb, struct ArrowSchemaImporter* schema)
      : rb_{rb}, schema_{schema} {
    CHECK(rb_) << "Cannot import non-existent record batch";
    CHECK(!schema_->columns.empty()) << "Cannot import record batch without a schema";
  }

  size_t Import(float missing) {
    auto& infov = schema_->columns;
    for (size_t i = 0; i < infov.size(); ++i) {
      columns_.push_back(CreateColumn(i, infov[i], missing));
    }

    // Compute the starting location for every row in this batch
    auto batch_size = rb_->length;
    auto num_columns = columns_.size();
    row_offsets_.resize(batch_size + 1, 0);
    for (auto i = 0; i < batch_size; ++i) {
      row_offsets_[i + 1] = row_offsets_[i];
      for (size_t j = 0; j < num_columns; ++j) {
        if (GetColumn(j).IsValidElement(i)) {
          row_offsets_[i + 1]++;
        }
      }
    }
    // return number of elements in the batch
    return row_offsets_.back();
  }

  ArrowColumnarBatch(const ArrowColumnarBatch&) = delete;
  ArrowColumnarBatch& operator=(const ArrowColumnarBatch&) = delete;
  ArrowColumnarBatch(ArrowColumnarBatch&&) = delete;
  ArrowColumnarBatch& operator=(ArrowColumnarBatch&&) = delete;

  virtual ~ArrowColumnarBatch() {
    if (rb_ && rb_->release) {
      rb_->release(rb_);
      rb_ = nullptr;
    }
    columns_.clear();
  }

  size_t Size() const { return rb_ ? rb_->length : 0; }

  size_t NumColumns() const { return columns_.size(); }

  size_t NumElements() const { return row_offsets_.back(); }

  const Column& GetColumn(size_t col_idx) const {
    return *columns_[col_idx];
  }

  void ShiftRowOffsets(size_t batch_offset) {
    std::transform(row_offsets_.begin(), row_offsets_.end(), row_offsets_.begin(),
                   [=](size_t c) { return c + batch_offset; });
  }

  const std::vector<size_t>& RowOffsets() const { return row_offsets_; }

 private:
  std::shared_ptr<Column> CreateColumn(size_t idx, ColumnarMetaInfo info, float missing) const {
    if (info.loc < 0) {
      return nullptr;
    }

    auto loc_in_batch = info.loc;
    auto length = rb_->length;
    auto null_count = rb_->null_count;
    auto buffers0 = rb_->children[loc_in_batch]->buffers[0];
    auto buffers1 = rb_->children[loc_in_batch]->buffers[1];
    const uint8_t* bitmap = buffers0 ? reinterpret_cast<const uint8_t*>(buffers0) : nullptr;
    const uint8_t* data = buffers1 ? reinterpret_cast<const uint8_t*>(buffers1) : nullptr;

    // if null_count is not computed, compute it here
    if (null_count < 0) {
      if (!bitmap) {
        null_count = 0;
      } else {
        null_count = length;
        for (auto i = 0; i < length; ++i) {
          if (bitmap[i / 8] & (1 << (i % 8))) {
            null_count--;
          }
        }
      }
    }

    switch (info.type) {
      case ColumnDType::kInt8:
        return std::make_shared<PrimitiveColumn<int8_t>>(
            idx, length, null_count, bitmap, reinterpret_cast<const int8_t*>(data), missing);
      case ColumnDType::kUInt8:
        return std::make_shared<PrimitiveColumn<uint8_t>>(
            idx, length, null_count, bitmap, data, missing);
      case ColumnDType::kInt16:
        return std::make_shared<PrimitiveColumn<int16_t>>(
            idx, length, null_count, bitmap, reinterpret_cast<const int16_t*>(data), missing);
      case ColumnDType::kUInt16:
        return std::make_shared<PrimitiveColumn<uint16_t>>(
            idx, length, null_count, bitmap, reinterpret_cast<const uint16_t*>(data), missing);
      case ColumnDType::kInt32:
        return std::make_shared<PrimitiveColumn<int32_t>>(
            idx, length, null_count, bitmap, reinterpret_cast<const int32_t*>(data), missing);
      case ColumnDType::kUInt32:
        return std::make_shared<PrimitiveColumn<uint32_t>>(
            idx, length, null_count, bitmap, reinterpret_cast<const uint32_t*>(data), missing);
      case ColumnDType::kInt64:
        return std::make_shared<PrimitiveColumn<int64_t>>(
            idx, length, null_count, bitmap, reinterpret_cast<const int64_t*>(data), missing);
      case ColumnDType::kUInt64:
        return std::make_shared<PrimitiveColumn<uint64_t>>(
            idx, length, null_count, bitmap, reinterpret_cast<const uint64_t*>(data), missing);
      case ColumnDType::kFloat:
        return std::make_shared<PrimitiveColumn<float>>(
            idx, length, null_count, bitmap, reinterpret_cast<const float*>(data), missing);
      case ColumnDType::kDouble:
        return std::make_shared<PrimitiveColumn<double>>(
            idx, length, null_count, bitmap, reinterpret_cast<const double*>(data), missing);
      default:
        return nullptr;
    }
  }

  struct ArrowArray* rb_;
  struct ArrowSchemaImporter* schema_;
  std::vector<std::shared_ptr<Column>> columns_;
  std::vector<size_t> row_offsets_;
};

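Import() above builds row_offsets_ as a CSR-style prefix sum: entry i+1 accumulates the count of valid (non-null, non-missing) cells in row i, so back() is the total number of stored elements. A toy reproduction of just that counting step (RowOffsets is a hypothetical helper):

#include <cstddef>
#include <vector>

std::vector<std::size_t> RowOffsets(std::vector<std::vector<bool>> const& valid) {
  std::vector<std::size_t> offsets(valid.size() + 1, 0);
  for (std::size_t i = 0; i < valid.size(); ++i) {
    offsets[i + 1] = offsets[i];  // start from the previous row's total
    for (bool v : valid[i]) {
      if (v) offsets[i + 1]++;    // count only valid cells
    }
  }
  return offsets;  // e.g. rows {{1,1},{0,1},{1,0}} -> {0, 2, 3, 4}
}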
using ArrowColumnarBatchVec = std::vector<std::unique_ptr<ArrowColumnarBatch>>;
class RecordBatchesIterAdapter : public dmlc::DataIter<ArrowColumnarBatchVec> {
 public:
  RecordBatchesIterAdapter(XGDMatrixCallbackNext* next_callback, int nbatch)
      : next_callback_{next_callback}, nbatches_{nbatch} {}

  void BeforeFirst() override {
    CHECK(at_first_) << "Cannot reset RecordBatchesIterAdapter";
  }

  bool Next() override {
    batches_.clear();
    while (batches_.size() < static_cast<size_t>(nbatches_) && (*next_callback_)(this) != 0) {
      at_first_ = false;
    }

    if (batches_.size() > 0) {
      return true;
    } else {
      return false;
    }
  }

  void SetData(struct ArrowArray* rb, struct ArrowSchema* schema) {
    // The schema is only imported once, at the beginning, regardless of how many
    // batches are coming. But even when the schema is not imported, we still need
    // to release its C data exported from Arrow.
    if (at_first_ && schema) {
      schema_.Import(schema);
    } else {
      if (schema && schema->release) {
        schema->release(schema);
      }
    }
    if (rb) {
      batches_.push_back(std::make_unique<ArrowColumnarBatch>(rb, &schema_));
    }
  }

  const ArrowColumnarBatchVec& Value() const override {
    return batches_;
  }

  size_t NumColumns() const { return schema_.columns.size(); }
  size_t NumRows() const { return kAdapterUnknownSize; }

 private:
  XGDMatrixCallbackNext *next_callback_;
  bool at_first_{true};
  int nbatches_;
  struct ArrowSchemaImporter schema_;
  ArrowColumnarBatchVec batches_;
};

class SparsePageAdapterBatch {
  HostSparsePageView page_;

src/data/array_interface.h
@@ -16,7 +16,7 @@
#include <utility>
#include <vector>

#include "../common/bitfield.h"
#include "../common/bitfield.h"  // for RBitField8
#include "../common/common.h"
#include "../common/error_msg.h"  // for NoF128
#include "xgboost/base.h"
@@ -106,7 +106,20 @@ struct ArrayInterfaceErrors {
 */
class ArrayInterfaceHandler {
 public:
  enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
  enum Type : std::int8_t {
    kF2 = 0,
    kF4 = 1,
    kF8 = 2,
    kF16 = 3,
    kI1 = 4,
    kI2 = 5,
    kI4 = 6,
    kI8 = 7,
    kU1 = 8,
    kU2 = 9,
    kU4 = 10,
    kU8 = 11,
  };

  template <typename PtrType>
  static PtrType GetPtrFromArrayData(Object::Map const &obj) {
@@ -589,6 +602,57 @@ class ArrayInterface {
  ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
};

template <typename Fn>
auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) {
  switch (dtype) {
    case ArrayInterfaceHandler::kF2: {
#if defined(XGBOOST_USE_CUDA) || defined(__HIP_PLATFORM_AMD__)
      return dispatch(__half{});
#else
      LOG(FATAL) << "half type is only supported for CUDA input.";
      break;
#endif
    }
    case ArrayInterfaceHandler::kF4: {
      return dispatch(float{});
    }
    case ArrayInterfaceHandler::kF8: {
      return dispatch(double{});
    }
    case ArrayInterfaceHandler::kF16: {
      using T = long double;
      CHECK(sizeof(T) == 16) << error::NoF128();
      return dispatch(T{});
    }
    case ArrayInterfaceHandler::kI1: {
      return dispatch(std::int8_t{});
    }
    case ArrayInterfaceHandler::kI2: {
      return dispatch(std::int16_t{});
    }
    case ArrayInterfaceHandler::kI4: {
      return dispatch(std::int32_t{});
    }
    case ArrayInterfaceHandler::kI8: {
      return dispatch(std::int64_t{});
    }
    case ArrayInterfaceHandler::kU1: {
      return dispatch(std::uint8_t{});
    }
    case ArrayInterfaceHandler::kU2: {
      return dispatch(std::uint16_t{});
    }
    case ArrayInterfaceHandler::kU4: {
      return dispatch(std::uint32_t{});
    }
    case ArrayInterfaceHandler::kU8: {
      return dispatch(std::uint64_t{});
    }
  }

  return std::result_of_t<Fn(std::int8_t)>();
}

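DispatchDType above converts the runtime dtype tag into a compile-time type by invoking the functor with a value of that type, and its __HIP_PLATFORM_AMD__ branch is one of the hooks that lets the __half path compile under ROCm as well as CUDA. A usage sketch against the function defined above (the call site itself is hypothetical):

// One generic lambda services every supported dtype; here it just reports
// the element width, instantiated once per type the switch can produce.
auto n_bytes = DispatchDType(ArrayInterfaceHandler::kF8, [](auto v) {
  using T = decltype(v);
  return sizeof(T);  // 8 for kF8 (double)
});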
template <std::int32_t D, typename Fn>
void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
  // Only used for cuDF at the moment.
@@ -604,60 +668,7 @@ void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
                     std::numeric_limits<std::size_t>::max()},
                    array.shape, array.strides, device});
  };
  switch (array.type) {
    case ArrayInterfaceHandler::kF2: {
#if defined(XGBOOST_USE_CUDA) || defined(__HIP_PLATFORM_AMD__)
      dispatch(__half{});
#endif
      break;
    }
    case ArrayInterfaceHandler::kF4: {
      dispatch(float{});
      break;
    }
    case ArrayInterfaceHandler::kF8: {
      dispatch(double{});
      break;
    }
    case ArrayInterfaceHandler::kF16: {
      using T = long double;
      CHECK(sizeof(long double) == 16) << error::NoF128();
      dispatch(T{});
      break;
    }
    case ArrayInterfaceHandler::kI1: {
      dispatch(std::int8_t{});
      break;
    }
    case ArrayInterfaceHandler::kI2: {
      dispatch(std::int16_t{});
      break;
    }
    case ArrayInterfaceHandler::kI4: {
      dispatch(std::int32_t{});
      break;
    }
    case ArrayInterfaceHandler::kI8: {
      dispatch(std::int64_t{});
      break;
    }
    case ArrayInterfaceHandler::kU1: {
      dispatch(std::uint8_t{});
      break;
    }
    case ArrayInterfaceHandler::kU2: {
      dispatch(std::uint16_t{});
      break;
    }
    case ArrayInterfaceHandler::kU4: {
      dispatch(std::uint32_t{});
      break;
    }
    case ArrayInterfaceHandler::kU8: {
      dispatch(std::uint64_t{});
      break;
    }
  }
  DispatchDType(array.type, dispatch);
}

/**

src/data/arrow-cdi.h (deleted)
@@ -1,66 +0,0 @@
/* Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#pragma once

#include <cstdint>

#ifdef __cplusplus
extern "C" {
#endif

#define ARROW_FLAG_DICTIONARY_ORDERED 1
#define ARROW_FLAG_NULLABLE 2
#define ARROW_FLAG_MAP_KEYS_SORTED 4

struct ArrowSchema {
  // Array type description
  const char* format;
  const char* name;
  const char* metadata;
  int64_t flags;
  int64_t n_children;
  struct ArrowSchema** children;
  struct ArrowSchema* dictionary;

  // Release callback
  void (*release)(struct ArrowSchema*);
  // Opaque producer-specific data
  void* private_data;
};

struct ArrowArray {
  // Array data description
  int64_t length;
  int64_t null_count;
  int64_t offset;
  int64_t n_buffers;
  int64_t n_children;
  const void** buffers;
  struct ArrowArray** children;
  struct ArrowArray* dictionary;

  // Release callback
  void (*release)(struct ArrowArray*);
  // Opaque producer-specific data
  void* private_data;
};

#ifdef __cplusplus
}
#endif
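Both structs carry the Arrow C data interface's ownership rule, which the `if (schema->release)` guards in the removed adapter code rely on: the consumer calls release exactly once when finished, and a released or moved structure has its release pointer set to null. A minimal consumer sketch (ConsumeSchema is a hypothetical name):

void ConsumeSchema(struct ArrowSchema* schema) {
  // ... read schema->format, schema->children, and so on ...
  if (schema->release) {      // null release means already released/moved
    schema->release(schema);  // hands every buffer back to the producer
  }
}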
src/data/data.cc (104 changes)
@@ -635,22 +635,39 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
}

void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
  if (size != 0 && this->num_col_ != 0) {
  if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
    CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
    CHECK(info);
  }
  if (!std::strcmp(key, "feature_type")) {
    feature_type_names.clear();
    auto& h_feature_types = feature_types.HostVector();
    for (size_t i = 0; i < size; ++i) {
      auto elem = info[i];
      feature_type_names.emplace_back(elem);
    }
    if (IsColumnSplit()) {
      feature_type_names = collective::AllgatherStrings(feature_type_names);
      CHECK_EQ(feature_type_names.size(), num_col_)
          << "Length of " << key << " must be equal to number of columns.";
    }
    auto& h_feature_types = feature_types.HostVector();
    LoadFeatureType(feature_type_names, &h_feature_types);
  } else if (!std::strcmp(key, "feature_name")) {
    feature_names.clear();
    for (size_t i = 0; i < size; ++i) {
      feature_names.emplace_back(info[i]);
    if (IsColumnSplit()) {
      std::vector<std::string> local_feature_names{};
      auto const rank = collective::GetRank();
      for (std::size_t i = 0; i < size; ++i) {
        auto elem = std::to_string(rank) + "." + info[i];
        local_feature_names.emplace_back(elem);
      }
      feature_names = collective::AllgatherStrings(local_feature_names);
      CHECK_EQ(feature_names.size(), num_col_)
          << "Length of " << key << " must be equal to number of columns.";
    } else {
      feature_names.clear();
      for (size_t i = 0; i < size; ++i) {
        feature_names.emplace_back(info[i]);
      }
    }
  } else {
    LOG(FATAL) << "Unknown feature info name: " << key;
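Under a column-wise data split each worker holds only its own slice of columns, so the new feature_name path prefixes every local name with the worker rank before the allgather, keeping the merged list free of collisions. A toy reproduction of the prefixing step (PrefixWithRank is a hypothetical helper; the real gather is collective::AllgatherStrings):

#include <string>
#include <vector>

// Rank 0 holding {"f0"} and rank 1 holding {"f0"} contribute {"0.f0"} and
// {"1.f0"}, so the gathered feature list stays unique across workers.
std::vector<std::string> PrefixWithRank(int rank, std::vector<std::string> const& names) {
  std::vector<std::string> out;
  out.reserve(names.size());
  for (auto const& n : names) {
    out.push_back(std::to_string(rank) + "." + n);
  }
  return out;
}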
@@ -687,13 +704,13 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col

  linalg::Stack(&this->labels, that.labels);

  this->weights_.SetDevice(that.weights_.DeviceIdx());
  this->weights_.SetDevice(that.weights_.Device());
  this->weights_.Extend(that.weights_);

  this->labels_lower_bound_.SetDevice(that.labels_lower_bound_.DeviceIdx());
  this->labels_lower_bound_.SetDevice(that.labels_lower_bound_.Device());
  this->labels_lower_bound_.Extend(that.labels_lower_bound_);

  this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx());
  this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.Device());
  this->labels_upper_bound_.Extend(that.labels_upper_bound_);

  linalg::Stack(&this->base_margin_, that.base_margin_);
@@ -723,13 +740,13 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
  }
  if (!that.feature_weights.Empty()) {
    this->feature_weights.Resize(that.feature_weights.Size());
    this->feature_weights.SetDevice(that.feature_weights.DeviceIdx());
    this->feature_weights.SetDevice(that.feature_weights.Device());
    this->feature_weights.Copy(that.feature_weights);
  }
}

void MetaInfo::SynchronizeNumberOfColumns() {
  if (IsVerticalFederated()) {
  if (IsColumnSplit()) {
    collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
  } else {
    collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
@@ -738,22 +755,22 @@ void MetaInfo::SynchronizeNumberOfColumns() {

namespace {
template <typename T>
void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
  bool valid = v.Device().IsCPU() || device == Context::kCpuId || v.DeviceIdx() == device;
void CheckDevice(DeviceOrd device, HostDeviceVector<T> const& v) {
  bool valid = v.Device().IsCPU() || device.IsCPU() || v.Device() == device;
  if (!valid) {
    LOG(FATAL) << "Invalid device ordinal. Data is associated with a different device ordinal than "
                  "the booster. The device ordinal of the data is: "
               << v.DeviceIdx() << "; the device ordinal of the Booster is: " << device;
               << v.Device() << "; the device ordinal of the Booster is: " << device;
  }
}

template <typename T, std::int32_t D>
void CheckDevice(std::int32_t device, linalg::Tensor<T, D> const& v) {
void CheckDevice(DeviceOrd device, linalg::Tensor<T, D> const& v) {
  CheckDevice(device, *v.Data());
}
}  // anonymous namespace

void MetaInfo::Validate(std::int32_t device) const {
void MetaInfo::Validate(DeviceOrd device) const {
  if (group_ptr_.size() != 0 && weights_.Size() != 0) {
    CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << error::GroupWeight();
    return;
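The switch from std::int32_t ordinals (where Context::kCpuId, i.e. -1, meant CPU) to DeviceOrd runs through the rest of this commit. A minimal stand-in capturing the idiom; the factories and predicates match the calls in the diff, while the struct layout here is an assumption, not the real header:

#include <cassert>
#include <cstdint>

struct DeviceOrd {
  std::int32_t ordinal{-1};
  static DeviceOrd CPU() { return {-1}; }
  static DeviceOrd CUDA(std::int32_t i) { return {i}; }
  bool IsCPU() const { return ordinal < 0; }
  bool IsCUDA() const { return ordinal >= 0; }
};

int main() {
  assert(DeviceOrd::CPU().IsCPU());
  assert(DeviceOrd::CUDA(0).IsCUDA());
  assert(DeviceOrd::CUDA(0).ordinal == 0);  // ordinal only meaningful on GPU
}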
@@ -850,14 +867,6 @@ DMatrix* TryLoadBinary(std::string fname, bool silent) {
}  // namespace

DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
  auto need_split = false;
  if (collective::IsFederated()) {
    LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
  } else if (collective::IsDistributed()) {
    LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
    need_split = true;
  }

  std::string fname, cache_file;
  auto dlm_pos = uri.find('#');
  if (dlm_pos != std::string::npos) {
@@ -865,24 +874,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
    fname = uri.substr(0, dlm_pos);
    CHECK_EQ(cache_file.find('#'), std::string::npos)
        << "Only one `#` is allowed in file path for cache file specification.";
    if (need_split && data_split_mode == DataSplitMode::kRow) {
      std::ostringstream os;
      std::vector<std::string> cache_shards = common::Split(cache_file, ':');
      for (size_t i = 0; i < cache_shards.size(); ++i) {
        size_t pos = cache_shards[i].rfind('.');
        if (pos == std::string::npos) {
          os << cache_shards[i] << ".r" << collective::GetRank() << "-"
             << collective::GetWorldSize();
        } else {
          os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-"
             << collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length());
        }
        if (i + 1 != cache_shards.size()) {
          os << ':';
        }
      }
      cache_file = os.str();
    }
  } else {
    fname = uri;
  }
@@ -894,19 +885,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
  }

  int partid = 0, npart = 1;
  if (need_split && data_split_mode == DataSplitMode::kRow) {
    partid = collective::GetRank();
    npart = collective::GetWorldSize();
  } else {
    // test option to load in part
    npart = 1;
  }

  if (npart != 1) {
    LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
  }

  DMatrix* dmat{nullptr};
  DMatrix* dmat{};

  if (cache_file.empty()) {
    fname = data::ValidateFileFormat(fname);
@@ -916,6 +895,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
    dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
                           cache_file, data_split_mode);
  } else {
    CHECK(data_split_mode != DataSplitMode::kCol)
        << "Column-wise data split is not supported for external memory.";
    data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart)};
    dmat = new data::SparsePageDMatrix{&iter,
                                       iter.Proxy(),
@@ -926,17 +907,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
                                       cache_file};
  }

  if (need_split && data_split_mode == DataSplitMode::kCol) {
    if (!cache_file.empty()) {
      LOG(FATAL) << "Column-wise data split is not support for external memory.";
    }
    LOG(CONSOLE) << "Splitting data by column";
    auto* sliced = dmat->SliceCol(npart, partid);
    delete dmat;
    return sliced;
  } else {
    return dmat;
  }
  return dmat;
}

template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
@@ -1011,9 +982,6 @@ template DMatrix* DMatrix::Create<data::CSCArrayAdapter>(data::CSCArrayAdapter*
template DMatrix* DMatrix::Create(
    data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
    float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::RecordBatchesIterAdapter>(
    data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&,
    DataSplitMode data_split_mode);

SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
  SparsePage transpose;

src/data/data.cu
@@ -33,13 +33,13 @@ template <typename T, int32_t D>
void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
  ArrayInterface<D> array(arr_interface);
  if (array.n == 0) {
    p_out->SetDevice(0);
    p_out->SetDevice(DeviceOrd::CUDA(0));
    p_out->Reshape(array.shape);
    return;
  }
  CHECK_EQ(array.valid.Capacity(), 0)
      << "Meta info like label or weight can not have missing value.";
  auto ptr_device = SetDeviceToPtr(array.data);
  auto ptr_device = DeviceOrd::CUDA(SetDeviceToPtr(array.data));
  p_out->SetDevice(ptr_device);

  if (array.is_contiguous && array.type == ToDType<T>::kType) {
@@ -55,7 +55,7 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens
    return;
  }
  p_out->Reshape(array.shape);
  auto t = p_out->View(DeviceOrd::CUDA(ptr_device));
  auto t = p_out->View(ptr_device);
  linalg::ElementWiseTransformDevice(
      t,
      [=] __device__(size_t i, T) {
@@ -91,7 +91,7 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
      });
  dh::caching_device_vector<bool> flag(1);
  auto d_flag = dh::ToSpan(flag);
  auto d = SetDeviceToPtr(array_interface.data);
  auto d = DeviceOrd::CUDA(SetDeviceToPtr(array_interface.data));
  dh::LaunchN(1, [=] __device__(size_t) { d_flag[0] = true; });
  dh::LaunchN(array_interface.Shape(0) - 1, [=] __device__(size_t i) {
    auto typed = TypedIndex<uint32_t, 1>{array_interface};

src/data/adapter.cuh
@@ -28,8 +28,8 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
  CudfAdapterBatch(common::Span<ArrayInterface<1>> columns, size_t num_rows)
      : columns_(columns),
        num_rows_(num_rows) {}
  size_t Size() const { return num_rows_ * columns_.size(); }
  __device__ __forceinline__ COOTuple GetElement(size_t idx) const {
  [[nodiscard]] std::size_t Size() const { return num_rows_ * columns_.size(); }
  [[nodiscard]] __device__ __forceinline__ COOTuple GetElement(size_t idx) const {
    size_t column_idx = idx % columns_.size();
    size_t row_idx = idx / columns_.size();
    auto const& column = columns_[column_idx];
@@ -39,7 +39,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
    return {row_idx, column_idx, value};
  }

  __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
  [[nodiscard]] __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
    auto const& column = columns_[fidx];
    float value = column.valid.Data() == nullptr || column.valid.Check(ridx)
                      ? column(ridx)
@@ -47,8 +47,8 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
    return value;
  }

  XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; }
  XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); }
  [[nodiscard]] XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; }
  [[nodiscard]] XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); }

 private:
  common::Span<ArrayInterface<1>> columns_;
@@ -120,16 +120,14 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
      return;
    }

    device_idx_ = dh::CudaGetPointerDevice(first_column.data);
    CHECK_NE(device_idx_, Context::kCpuId);

    dh::safe_cuda(cudaSetDevice(device_idx_));

    device_ = DeviceOrd::CUDA(dh::CudaGetPointerDevice(first_column.data));
    CHECK(device_.IsCUDA());
    dh::safe_cuda(cudaSetDevice(device_.ordinal));
    for (auto& json_col : json_columns) {
      auto column = ArrayInterface<1>(get<Object const>(json_col));
      columns.push_back(column);
      num_rows_ = std::max(num_rows_, column.Shape(0));
      CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
      CHECK_EQ(device_.ordinal, dh::CudaGetPointerDevice(column.data))
          << "All columns should use the same device.";
      CHECK_EQ(num_rows_, column.Shape(0))
          << "All columns should have same number of rows.";
@@ -145,15 +143,15 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
    return batch_;
  }

  size_t NumRows() const { return num_rows_; }
  size_t NumColumns() const { return columns_.size(); }
  int32_t DeviceIdx() const { return device_idx_; }
  [[nodiscard]] std::size_t NumRows() const { return num_rows_; }
  [[nodiscard]] std::size_t NumColumns() const { return columns_.size(); }
  [[nodiscard]] DeviceOrd Device() const { return device_; }

 private:
  CudfAdapterBatch batch_;
  dh::device_vector<ArrayInterface<1>> columns_;
  size_t num_rows_{0};
  int32_t device_idx_{Context::kCpuId};
  DeviceOrd device_{DeviceOrd::CPU()};
};

class CupyAdapterBatch : public detail::NoMetaInfo {
@@ -161,22 +159,22 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
  CupyAdapterBatch() = default;
  explicit CupyAdapterBatch(ArrayInterface<2> array_interface)
      : array_interface_(std::move(array_interface)) {}
  size_t Size() const {
  [[nodiscard]] std::size_t Size() const {
    return array_interface_.Shape(0) * array_interface_.Shape(1);
  }
  __device__ COOTuple GetElement(size_t idx) const {
  [[nodiscard]] __device__ COOTuple GetElement(size_t idx) const {
    size_t column_idx = idx % array_interface_.Shape(1);
    size_t row_idx = idx / array_interface_.Shape(1);
    float value = array_interface_(row_idx, column_idx);
    return {row_idx, column_idx, value};
  }
  __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
  [[nodiscard]] __device__ float GetElement(bst_row_t ridx, bst_feature_t fidx) const {
    float value = array_interface_(ridx, fidx);
    return value;
  }

  XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.Shape(0); }
  XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.Shape(1); }
  [[nodiscard]] XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.Shape(0); }
  [[nodiscard]] XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.Shape(1); }

 private:
  ArrayInterface<2> array_interface_;
@@ -191,29 +189,28 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
    if (array_interface_.Shape(0) == 0) {
      return;
    }
    device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
    CHECK_NE(device_idx_, Context::kCpuId);
    device_ = DeviceOrd::CUDA(dh::CudaGetPointerDevice(array_interface_.data));
    CHECK(device_.IsCUDA());
  }
  explicit CupyAdapter(std::string cuda_interface_str)
      : CupyAdapter{StringView{cuda_interface_str}} {}
  const CupyAdapterBatch& Value() const override { return batch_; }
  [[nodiscard]] const CupyAdapterBatch& Value() const override { return batch_; }

  size_t NumRows() const { return array_interface_.Shape(0); }
  size_t NumColumns() const { return array_interface_.Shape(1); }
  int32_t DeviceIdx() const { return device_idx_; }
  [[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape(0); }
  [[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape(1); }
  [[nodiscard]] DeviceOrd Device() const { return device_; }

 private:
  ArrayInterface<2> array_interface_;
  CudfAdapterBatch batch_;
  int32_t device_idx_ {Context::kCpuId};
  DeviceOrd device_{DeviceOrd::CPU()};
};

// Returns maximum row length
template <typename AdapterBatchT>
std::size_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_row_t> offset, int device_idx,
std::size_t GetRowCounts(const AdapterBatchT batch, common::Span<bst_row_t> offset, DeviceOrd device,
                         float missing) {
  dh::safe_cuda(cudaSetDevice(device_idx));

  dh::safe_cuda(cudaSetDevice(device.ordinal));
  IsValidFunctor is_valid(missing);
  dh::safe_cuda(cudaMemsetAsync(offset.data(), '\0', offset.size_bytes()));

src/data/ellpack_page.cu
@@ -98,23 +98,18 @@ __global__ void CompressBinEllpackKernel(
}

// Construct an ELLPACK matrix with the given number of empty rows.
EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
                                 bool is_dense, size_t row_stride,
                                 size_t n_rows)
    : is_dense(is_dense),
      cuts_(std::move(cuts)),
      row_stride(row_stride),
      n_rows(n_rows) {
EllpackPageImpl::EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, bool is_dense,
                                 size_t row_stride, size_t n_rows)
    : is_dense(is_dense), cuts_(std::move(cuts)), row_stride(row_stride), n_rows(n_rows) {
  monitor_.Init("ellpack_page");

  dh::safe_cuda(cudaSetDevice(device));
  dh::safe_cuda(cudaSetDevice(device.ordinal));

  monitor_.Start("InitCompressedData");
  InitCompressedData(device);
  monitor_.Stop("InitCompressedData");
}

EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
EllpackPageImpl::EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts,
                                 const SparsePage &page, bool is_dense,
                                 size_t row_stride,
                                 common::Span<FeatureType const> feature_types)
@@ -128,7 +123,7 @@ EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& param)
    : is_dense(dmat->IsDense()) {
  monitor_.Init("ellpack_page");
  dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));

  n_rows = dmat->Info().num_row_;

@@ -143,15 +138,15 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchP
  monitor_.Stop("Quantiles");

  monitor_.Start("InitCompressedData");
  this->InitCompressedData(ctx->gpu_id);
  this->InitCompressedData(ctx->Device());
  monitor_.Stop("InitCompressedData");

  dmat->Info().feature_types.SetDevice(ctx->gpu_id);
  dmat->Info().feature_types.SetDevice(ctx->Device());
  auto ft = dmat->Info().feature_types.ConstDeviceSpan();
  monitor_.Start("BinningCompression");
  CHECK(dmat->SingleColBlock());
  for (const auto& batch : dmat->GetBatches<SparsePage>()) {
    CreateHistIndices(ctx->gpu_id, batch, ft);
    CreateHistIndices(ctx->Device(), batch, ft);
  }
  monitor_.Stop("BinningCompression");
}
@@ -214,7 +209,7 @@ struct TupleScanOp {
// to remove missing data
template <typename AdapterBatchT>
void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType const> feature_types,
                       EllpackPageImpl* dst, int device_idx, float missing) {
                       EllpackPageImpl* dst, DeviceOrd device, float missing) {
  // Some witchcraft happens here
  // The goal is to copy valid elements out of the input to an ELLPACK matrix
  // with a given row stride, using no extra working memory Standard stream
@@ -246,7 +241,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
  // Tuple[2] = The index in the input data
  using Tuple = thrust::tuple<size_t, size_t, size_t>;

  auto device_accessor = dst->GetDeviceAccessor(device_idx);
  auto device_accessor = dst->GetDeviceAccessor(device);
  common::CompressedBufferWriter writer(device_accessor.NumSymbols());
  auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();

@@ -298,10 +293,9 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span<FeatureType cons
#endif
}

void WriteNullValues(EllpackPageImpl* dst, int device_idx,
                     common::Span<size_t> row_counts) {
void WriteNullValues(EllpackPageImpl* dst, DeviceOrd device, common::Span<size_t> row_counts) {
  // Write the null values
  auto device_accessor = dst->GetDeviceAccessor(device_idx);
  auto device_accessor = dst->GetDeviceAccessor(device);
  common::CompressedBufferWriter writer(device_accessor.NumSymbols());
  auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
  auto row_stride = dst->row_stride;
@@ -318,11 +312,11 @@ void WriteNullValues(EllpackPageImpl* dst, int device_idx,
}

template <typename AdapterBatch>
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense,
EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, DeviceOrd device, bool is_dense,
                                 common::Span<size_t> row_counts_span,
                                 common::Span<FeatureType const> feature_types, size_t row_stride,
                                 size_t n_rows, common::HistogramCuts const& cuts) {
  dh::safe_cuda(cudaSetDevice(device));
  dh::safe_cuda(cudaSetDevice(device.ordinal));

  *this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows);
  CopyDataToEllpack(batch, feature_types, this, device, missing);
@@ -331,7 +325,7 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device,

#define ELLPACK_BATCH_SPECIALIZE(__BATCH_T)                                                \
  template EllpackPageImpl::EllpackPageImpl(                                               \
      __BATCH_T batch, float missing, int device, bool is_dense,                           \
      __BATCH_T batch, float missing, DeviceOrd device, bool is_dense,                     \
      common::Span<size_t> row_counts_span, common::Span<FeatureType const> feature_types, \
      size_t row_stride, size_t n_rows, common::HistogramCuts const& cuts);

@@ -388,9 +382,9 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
      [&](size_t i) { return page.row_ptr[i + 1] - page.row_ptr[i]; });
  row_stride = *std::max_element(it, it + page.Size());

  CHECK_GE(ctx->gpu_id, 0);
  CHECK(ctx->IsCUDA());
  monitor_.Start("InitCompressedData");
  InitCompressedData(ctx->gpu_id);
  InitCompressedData(ctx->Device());
  monitor_.Stop("InitCompressedData");

  // copy gidx
@@ -400,7 +394,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
  dh::safe_cuda(cudaMemcpyAsync(d_row_ptr.data(), page.row_ptr.data(), d_row_ptr.size_bytes(),
                                cudaMemcpyHostToDevice, ctx->CUDACtx()->Stream()));

  auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft);
  auto accessor = this->GetDeviceAccessor(ctx->Device(), ft);
  auto null = accessor.NullValue();
  CopyGHistToEllpack(page, d_row_ptr, row_stride, d_compressed_buffer, null);
}
@@ -425,8 +419,7 @@ struct CopyPage {
};

// Copy the data from the given EllpackPage to the current page.
size_t EllpackPageImpl::Copy(int device, EllpackPageImpl const *page,
                             size_t offset) {
size_t EllpackPageImpl::Copy(DeviceOrd device, EllpackPageImpl const* page, size_t offset) {
  monitor_.Start("Copy");
  size_t num_elements = page->n_rows * page->row_stride;
  CHECK_EQ(row_stride, page->row_stride);
@@ -486,7 +479,7 @@ struct CompactPage {
};

// Compacts the data from the given EllpackPage into the current page.
void EllpackPageImpl::Compact(int device, EllpackPageImpl const* page,
void EllpackPageImpl::Compact(DeviceOrd device, EllpackPageImpl const* page,
                              common::Span<size_t> row_indexes) {
  monitor_.Start("Compact");
  CHECK_EQ(row_stride, page->row_stride);
@@ -499,13 +492,12 @@ void EllpackPageImpl::Compact(int device, EllpackPageImpl const* page,
}

// Initialize the buffer to store compressed features.
void EllpackPageImpl::InitCompressedData(int device) {
void EllpackPageImpl::InitCompressedData(DeviceOrd device) {
  size_t num_symbols = NumSymbols();

  // Required buffer size for storing data matrix in ELLPack format.
  size_t compressed_size_bytes =
      common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
                                                          num_symbols);
      common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows, num_symbols);
  gidx_buffer.SetDevice(device);
  // Don't call fill unnecessarily
  if (gidx_buffer.Size() == 0) {
@@ -517,7 +509,7 @@ void EllpackPageImpl::InitCompressedData(int device) {
}

// Compress a CSR page into ELLPACK.
void EllpackPageImpl::CreateHistIndices(int device,
void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
                                        const SparsePage& row_batch,
                                        common::Span<FeatureType const> feature_types) {
  if (row_batch.Size() == 0) return;
@@ -527,7 +519,7 @@ void EllpackPageImpl::CreateHistIndices(int device,

  // bin and compress entries in batches of rows
  size_t gpu_batch_nrows =
      std::min(dh::TotalMemory(device) / (16 * row_stride * sizeof(Entry)),
      std::min(dh::TotalMemory(device.ordinal) / (16 * row_stride * sizeof(Entry)),
               static_cast<size_t>(row_batch.Size()));

  size_t gpu_nbatches = common::DivRoundUp(row_batch.Size(), gpu_batch_nrows);
@@ -592,7 +584,7 @@ size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride,
}

EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
    int device, common::Span<FeatureType const> feature_types) const {
    DeviceOrd device, common::Span<FeatureType const> feature_types) const {
  gidx_buffer.SetDevice(device);
  return {device,
          cuts_,
@@ -606,7 +598,7 @@ EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
}
EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
    common::Span<FeatureType const> feature_types) const {
  return {Context::kCpuId,
  return {DeviceOrd::CPU(),
          cuts_,
          is_dense,
          row_stride,

src/data/ellpack_page.cuh
@@ -35,16 +35,17 @@ struct EllpackDeviceAccessor {

  common::Span<const FeatureType> feature_types;

  EllpackDeviceAccessor(int device, const common::HistogramCuts& cuts,
                        bool is_dense, size_t row_stride, size_t base_rowid,
                        size_t n_rows,common::CompressedIterator<uint32_t> gidx_iter,
  EllpackDeviceAccessor(DeviceOrd device, const common::HistogramCuts& cuts, bool is_dense,
                        size_t row_stride, size_t base_rowid, size_t n_rows,
                        common::CompressedIterator<uint32_t> gidx_iter,
                        common::Span<FeatureType const> feature_types)
      : is_dense(is_dense),
        row_stride(row_stride),
        base_rowid(base_rowid),
        n_rows(n_rows) ,gidx_iter(gidx_iter),
        n_rows(n_rows),
        gidx_iter(gidx_iter),
        feature_types{feature_types} {
    if (device == Context::kCpuId) {
    if (device.IsCPU()) {
      gidx_fvalue_map = cuts.cut_values_.ConstHostSpan();
      feature_segments = cuts.cut_ptrs_.ConstHostSpan();
      min_fvalue = cuts.min_vals_.ConstHostSpan();
@@ -59,7 +60,7 @@ struct EllpackDeviceAccessor {
  }
  // Get a matrix element; uses binary search for lookup. Returns NaN if missing.
  // Given a row index and a feature index, returns the corresponding cut value
  __device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
  [[nodiscard]] __device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
    ridx -= base_rowid;
    auto row_begin = row_stride * ridx;
    auto row_end = row_begin + row_stride;
@@ -77,7 +78,7 @@ struct EllpackDeviceAccessor {
  }

  template <bool is_cat>
  __device__ uint32_t SearchBin(float value, size_t column_id) const {
  [[nodiscard]] __device__ uint32_t SearchBin(float value, size_t column_id) const {
    auto beg = feature_segments[column_id];
    auto end = feature_segments[column_id + 1];
    uint32_t idx = 0;
@@ -99,7 +100,7 @@ struct EllpackDeviceAccessor {
    return idx;
  }

  __device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
  [[nodiscard]] __device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
    auto gidx = GetBinIndex(ridx, fidx);
    if (gidx == -1) {
      return nan("");
@@ -108,18 +109,18 @@ struct EllpackDeviceAccessor {
  }

  // Check if the row id is within range of the current batch.
  __device__ bool IsInRange(size_t row_id) const {
  [[nodiscard]] __device__ bool IsInRange(size_t row_id) const {
    return row_id >= base_rowid && row_id < base_rowid + n_rows;
  }
  /*! \brief Return the total number of symbols (total number of bins plus 1 for
   * not found). */
  XGBOOST_DEVICE size_t NumSymbols() const { return gidx_fvalue_map.size() + 1; }
  [[nodiscard]] XGBOOST_DEVICE size_t NumSymbols() const { return gidx_fvalue_map.size() + 1; }

  XGBOOST_DEVICE size_t NullValue() const { return gidx_fvalue_map.size(); }
  [[nodiscard]] XGBOOST_DEVICE size_t NullValue() const { return gidx_fvalue_map.size(); }

  XGBOOST_DEVICE size_t NumBins() const { return gidx_fvalue_map.size(); }
  [[nodiscard]] XGBOOST_DEVICE size_t NumBins() const { return gidx_fvalue_map.size(); }

  XGBOOST_DEVICE size_t NumFeatures() const { return min_fvalue.size(); }
  [[nodiscard]] XGBOOST_DEVICE size_t NumFeatures() const { return min_fvalue.size(); }
};

@@ -141,14 +142,13 @@ class EllpackPageImpl {
* This is used in the sampling case. The ELLPACK page is constructed from an existing EllpackInfo
* and the given number of rows.
*/
EllpackPageImpl(int device, common::HistogramCuts cuts, bool is_dense,
size_t row_stride, size_t n_rows);
EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, bool is_dense, size_t row_stride,
size_t n_rows);
/*!
* \brief Constructor used for external memory.
*/
EllpackPageImpl(int device, common::HistogramCuts cuts,
const SparsePage &page, bool is_dense, size_t row_stride,
common::Span<FeatureType const> feature_types);
EllpackPageImpl(DeviceOrd device, common::HistogramCuts cuts, const SparsePage& page,
bool is_dense, size_t row_stride, common::Span<FeatureType const> feature_types);

/*!
* \brief Constructor from an existing DMatrix.
@@ -159,7 +159,7 @@ class EllpackPageImpl {
explicit EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& parm);

template <typename AdapterBatch>
explicit EllpackPageImpl(AdapterBatch batch, float missing, int device, bool is_dense,
explicit EllpackPageImpl(AdapterBatch batch, float missing, DeviceOrd device, bool is_dense,
common::Span<size_t> row_counts_span,
common::Span<FeatureType const> feature_types, size_t row_stride,
size_t n_rows, common::HistogramCuts const& cuts);
@@ -176,7 +176,7 @@ class EllpackPageImpl {
* @param offset The number of elements to skip before copying.
* @returns The number of elements copied.
*/
size_t Copy(int device, EllpackPageImpl const *page, size_t offset);
size_t Copy(DeviceOrd device, EllpackPageImpl const *page, size_t offset);

/*! \brief Compact the given ELLPACK page into the current page.
*
@@ -184,11 +184,10 @@ class EllpackPageImpl {
* @param page The ELLPACK page to compact from.
* @param row_indexes Row indexes for the compacted page.
*/
void Compact(int device, EllpackPageImpl const* page, common::Span<size_t> row_indexes);

void Compact(DeviceOrd device, EllpackPageImpl const* page, common::Span<size_t> row_indexes);

/*! \return Number of instances in the page. */
size_t Size() const;
[[nodiscard]] size_t Size() const;

/*! \brief Set the base row id for this page. */
void SetBaseRowId(std::size_t row_id) {
@@ -204,12 +203,12 @@ class EllpackPageImpl {

/*! \brief Return the total number of symbols (total number of bins plus 1 for
* not found). */
size_t NumSymbols() const { return cuts_.TotalBins() + 1; }
[[nodiscard]] std::size_t NumSymbols() const { return cuts_.TotalBins() + 1; }

EllpackDeviceAccessor
GetDeviceAccessor(int device,
common::Span<FeatureType const> feature_types = {}) const;
EllpackDeviceAccessor GetHostAccessor(common::Span<FeatureType const> feature_types = {}) const;
[[nodiscard]] EllpackDeviceAccessor GetDeviceAccessor(
DeviceOrd device, common::Span<FeatureType const> feature_types = {}) const;
[[nodiscard]] EllpackDeviceAccessor GetHostAccessor(
common::Span<FeatureType const> feature_types = {}) const;

private:
/*!
@@ -218,13 +217,13 @@ class EllpackPageImpl {
* @param device The GPU device to use.
* @param row_batch The CSR page.
*/
void CreateHistIndices(int device,
void CreateHistIndices(DeviceOrd device,
const SparsePage& row_batch,
common::Span<FeatureType const> feature_types);
/*!
* \brief Initialize the buffer to store compressed features.
*/
void InitCompressedData(int device);
void InitCompressedData(DeviceOrd device);

public:
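The change that repeats through these signatures is the switch from a raw `int` ordinal to `DeviceOrd`. A rough sketch of the interface these call sites depend on; the real type lives in `xgboost/context.h`, and this compressed version is illustrative only:

    #include <cstdint>
    #include <string>

    struct DeviceOrd {  // sketch, not the real declaration
      std::int32_t ordinal{-1};  // CPU modeled here as ordinal -1 for brevity
      static DeviceOrd CPU() { return DeviceOrd{-1}; }
      static DeviceOrd CUDA(std::int32_t i) { return DeviceOrd{i}; }
      bool IsCPU() const { return ordinal < 0; }
      bool IsCUDA() const { return ordinal >= 0; }
      std::string Name() const { return IsCPU() ? "cpu" : "cuda:" + std::to_string(ordinal); }
    };

Carrying the device kind and ordinal in one value is what lets later hunks replace `device < 0` checks with `device.IsCPU()` and pass `device.ordinal` to `cudaSetDevice`.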
@@ -10,7 +10,7 @@

namespace xgboost::data {
void EllpackPageSource::Fetch() {
dh::safe_cuda(cudaSetDevice(device_));
dh::safe_cuda(cudaSetDevice(device_.ordinal));
if (!this->ReadCache()) {
if (count_ != 0 && !sync_) {
// source is initialized to be the 0th page during construction, so when count_ is 0

@@ -23,14 +23,14 @@ class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
BatchParam param_;
common::Span<FeatureType const> feature_types_;
std::unique_ptr<common::HistogramCuts> cuts_;
std::int32_t device_;
DeviceOrd device_;

public:
EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
std::shared_ptr<Cache> cache, BatchParam param,
std::unique_ptr<common::HistogramCuts> cuts, bool is_dense, size_t row_stride,
common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source, std::int32_t device)
std::shared_ptr<SparsePageSource> source, DeviceOrd device)
: PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
is_dense_{is_dense},
row_stride_{row_stride},
@@ -50,6 +50,7 @@ inline void EllpackPageSource::Fetch() {
// silence the warning about unused variables.
(void)(row_stride_);
(void)(is_dense_);
(void)(device_);
common::AssertGPUSupport();
}
#endif  // !defined(XGBOOST_USE_CUDA)
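The added `(void)(device_);` follows the stub pattern used for CPU-only builds: the no-GPU `Fetch()` body deliberately "uses" every GPU-only member so the compiler does not warn about unused fields, then asserts that GPU support is required. Condensed from the hunk above (sketch):

    #if !defined(XGBOOST_USE_CUDA)
    inline void EllpackPageSource::Fetch() {
      (void)(row_stride_);  // touch GPU-only members to silence unused warnings
      (void)(is_dense_);
      (void)(device_);
      common::AssertGPUSupport();  // fail loudly: this path needs a GPU build
    }
    #endif  // !defined(XGBOOST_USE_CUDA)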
@@ -36,8 +36,7 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro
auto pctx = MakeProxy(proxy_)->Ctx();

Context ctx;
ctx.UpdateAllowUnknown(
Args{{"nthread", std::to_string(nthread)}, {"device", pctx->DeviceName()}});
ctx.Init(Args{{"nthread", std::to_string(nthread)}, {"device", pctx->DeviceName()}});
// hardcoded parameter.
BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};

@@ -139,7 +138,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
return HostAdapterDispatch(proxy, [&](auto const& value) {
size_t n_threads = ctx->Threads();
size_t n_features = column_sizes.size();
linalg::Tensor<std::size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
linalg::Tensor<std::size_t, 2> column_sizes_tloc({n_threads, n_features}, DeviceOrd::CPU());
column_sizes_tloc.Data()->Fill(0ul);
auto view = column_sizes_tloc.HostView();
common::ParallelFor(value.Size(), n_threads, common::Sched::Static(256), [&](auto i) {
@@ -48,10 +48,9 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
int32_t current_device;
dh::safe_cuda(cudaGetDevice(&current_device));
auto get_device = [&]() -> int32_t {
std::int32_t d = (ctx->gpu_id == Context::kCpuId) ? current_device : ctx->gpu_id;
CHECK_NE(d, Context::kCpuId);
auto get_device = [&]() {
auto d = (ctx->IsCPU()) ? DeviceOrd::CUDA(current_device) : ctx->Device();
CHECK(!d.IsCPU());
return d;
};

@@ -61,11 +60,8 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
common::HistogramCuts cuts;
do {
// We use do-while here as the first batch is fetched in the ctor.
// ctx_.gpu_id = proxy->DeviceIdx();
CHECK_LT(ctx->gpu_id, common::AllVisibleGPUs());
dh::safe_cuda(cudaSetDevice(get_device()));
CHECK_LT(ctx->Ordinal(), common::AllVisibleGPUs());
dh::safe_cuda(cudaSetDevice(get_device().ordinal));
if (cols == 0) {
cols = num_cols();
collective::Allreduce<collective::Operation::kMax>(&cols, 1);
@@ -103,8 +99,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
auto n_features = cols;
CHECK_GE(n_features, 1) << "Data must have at least 1 column.";
dh::safe_cuda(cudaSetDevice(get_device()));
dh::safe_cuda(cudaSetDevice(get_device().ordinal));
if (!ref) {
HostDeviceVector<FeatureType> ft;
common::SketchContainer final_sketch(
@@ -143,9 +138,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
size_t n_batches_for_verification = 0;
while (iter.Next()) {
init_page();
dh::safe_cuda(cudaSetDevice(get_device()));
dh::safe_cuda(cudaSetDevice(get_device().ordinal));
auto rows = num_rows();
dh::device_vector<size_t> row_counts(rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
@@ -197,18 +190,18 @@ BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
if (!ellpack_) {
ellpack_.reset(new EllpackPage());
if (ctx->IsCUDA()) {
this->Info().feature_types.SetDevice(ctx->gpu_id);
this->Info().feature_types.SetDevice(ctx->Device());
*ellpack_->Impl() =
EllpackPageImpl(ctx, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
} else if (fmat_ctx_.IsCUDA()) {
this->Info().feature_types.SetDevice(fmat_ctx_.gpu_id);
this->Info().feature_types.SetDevice(fmat_ctx_.Device());
*ellpack_->Impl() =
EllpackPageImpl(&fmat_ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
} else {
// Can happen when QDM is initialized on CPU, but a GPU version is queried by a different QDM
// for cut reference.
auto cuda_ctx = ctx->MakeCUDA();
this->Info().feature_types.SetDevice(cuda_ctx.gpu_id);
this->Info().feature_types.SetDevice(cuda_ctx.Device());
*ellpack_->Impl() =
EllpackPageImpl(&cuda_ctx, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
}
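The final `else` branch handles the corner case called out in the comment: an ELLPACK page is requested for the GPU while both the caller's context and `fmat_ctx_` are CPU, so a temporary CUDA context is derived on the spot. A hedged usage sketch of that derivation (assuming `MakeCUDA` with no argument defaults to the first device, as the argument-free call above suggests):

    Context cpu_ctx;                        // device defaults to "cpu"
    Context cuda_ctx = cpu_ctx.MakeCUDA();  // same settings, device switched to CUDA
    // cpu_ctx.IsCPU() == true; cuda_ctx.IsCUDA() == true
    // cuda_ctx.Device() is then used to place feature_types before building the page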
@@ -11,18 +11,18 @@ void DMatrixProxy::SetArrayData(StringView interface_str) {
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
this->ctx_.gpu_id = Context::kCpuId;
this->ctx_.Init(Args{{"device", "cpu"}});
}

void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices,
char const *c_values, bst_feature_t n_features, bool on_host) {
void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices, char const *c_values,
bst_feature_t n_features, bool on_host) {
CHECK(on_host) << "Not implemented on device.";
std::shared_ptr<CSRArrayAdapter> adapter{new CSRArrayAdapter(
StringView{c_indptr}, StringView{c_indices}, StringView{c_values}, n_features)};
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
this->ctx_.gpu_id = Context::kCpuId;
this->ctx_.Init(Args{{"device", "cpu"}});
}

namespace cuda_impl {
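Both setters now initialize the proxy context from a `device` argument instead of poking `ctx_.gpu_id` directly. Judging from the strings used across this commit, the accepted spellings are "cpu" and "cuda" / "cuda:<ordinal>"; a small usage sketch (argument plumbing simplified):

    Context ctx;
    ctx.Init(Args{{"nthread", "8"}, {"device", "cuda:1"}});  // select the second GPU
    // ctx.Device()        -> DeviceOrd::CUDA(1)
    // ctx.Device().Name() -> "cuda:1"
    // ctx.Ordinal()       -> 1, ctx.IsCUDA() -> true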
@@ -11,13 +11,13 @@ void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
if (adapter->DeviceIdx() < 0) {
if (adapter->Device().IsCPU()) {
// empty data
CHECK_EQ(this->Info().num_row_, 0);
ctx_ = ctx_.MakeCUDA(dh::CurrentDevice());
return;
}
ctx_ = ctx_.MakeCUDA(adapter->DeviceIdx());
ctx_ = ctx_.MakeCUDA(adapter->Device().ordinal);
}

void DMatrixProxy::FromCudaArray(StringView interface_str) {
@@ -25,13 +25,13 @@ void DMatrixProxy::FromCudaArray(StringView interface_str) {
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
if (adapter->DeviceIdx() < 0) {
if (adapter->Device().IsCPU()) {
// empty data
CHECK_EQ(this->Info().num_row_, 0);
ctx_ = ctx_.MakeCUDA(dh::CurrentDevice());
return;
}
ctx_ = ctx_.MakeCUDA(adapter->Device().ordinal);
}

namespace cuda_impl {
@@ -46,7 +46,7 @@ class DMatrixProxy : public DMatrix {
#endif  // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)

public:
int DeviceIdx() const { return ctx_.gpu_id; }
DeviceOrd Device() const { return ctx_.Device(); }

void SetCUDAArray(char const* c_interface) {
common::AssertGPUSupport();
@@ -75,11 +75,9 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
}

void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
if (info_.IsVerticalFederated()) {
std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
auto const cols = collective::Allgather(info_.num_col_);
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
if (offset == 0) {
return;
}
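The reindexing arithmetic deserves spelling out: each worker gathers every worker's column count, and its global feature offset is the sum of the counts owned by lower ranks. A standalone illustration without the collective machinery (values made up):

    #include <cstdint>
    #include <numeric>
    #include <vector>

    int main() {
      // Column counts gathered from workers 0..3, e.g. the result of Allgather.
      std::vector<std::uint64_t> cols{4, 3, 5, 2};
      int rank = 2;  // this worker
      // Worker 2 shifts by 4 + 3 = 7, so its local features 0..4 become 7..11.
      auto offset = std::accumulate(cols.cbegin(), cols.cbegin() + rank, std::uint64_t{0});
      return offset == 7 ? 0 : 1;
    }

The `0ul` literal in the new code serves the same purpose as the explicit `std::uint64_t{0}` here: the old code accumulated into a plain `int`, the new one keeps the sum unsigned.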
@@ -253,7 +251,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
}
if (batch.BaseMargin() != nullptr) {
info_.base_margin_ = decltype(info_.base_margin_){
batch.BaseMargin(), batch.BaseMargin() + batch.Size(), {batch.Size()}, Context::kCpuId};
batch.BaseMargin(), batch.BaseMargin() + batch.Size(), {batch.Size()}, DeviceOrd::CPU()};
}
if (batch.Qid() != nullptr) {
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
@@ -361,78 +359,4 @@ template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int n
template SimpleDMatrix::SimpleDMatrix(
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
float missing, int nthread, DataSplitMode data_split_mode);

template <>
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode) {
Context ctx;
ctx.nthread = nthread;

auto& offset_vec = sparse_page_->offset.HostVector();
auto& data_vec = sparse_page_->data.HostVector();
uint64_t total_batch_size = 0;
uint64_t total_elements = 0;

adapter->BeforeFirst();
// Iterate over batches of input data
while (adapter->Next()) {
auto& batches = adapter->Value();
size_t num_elements = 0;
size_t num_rows = 0;
// Import Arrow RecordBatches
#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(ctx.Threads())
for (int i = 0; i < static_cast<int>(batches.size()); ++i) {  // NOLINT
num_elements += batches[i]->Import(missing);
num_rows += batches[i]->Size();
}
total_elements += num_elements;
total_batch_size += num_rows;
// Compute global offset for every row and starting row for every batch
std::vector<uint64_t> batch_offsets(batches.size());
for (size_t i = 0; i < batches.size(); ++i) {
if (i == 0) {
batch_offsets[i] = total_batch_size - num_rows;
batches[i]->ShiftRowOffsets(total_elements - num_elements);
} else {
batch_offsets[i] = batch_offsets[i - 1] + batches[i - 1]->Size();
batches[i]->ShiftRowOffsets(batches[i - 1]->RowOffsets().back());
}
}
// Pre-allocate DMatrix memory
data_vec.resize(total_elements);
offset_vec.resize(total_batch_size + 1);
// Copy data into DMatrix
#pragma omp parallel num_threads(ctx.Threads())
{
#pragma omp for nowait
for (int i = 0; i < static_cast<int>(batches.size()); ++i) {  // NOLINT
size_t begin = batches[i]->RowOffsets()[0];
for (size_t k = 0; k < batches[i]->Size(); ++k) {
for (size_t j = 0; j < batches[i]->NumColumns(); ++j) {
auto element = batches[i]->GetColumn(j).GetElement(k);
if (!std::isnan(element.value)) {
data_vec[begin++] = Entry(element.column_idx, element.value);
}
}
}
}
#pragma omp for nowait
for (int i = 0; i < static_cast<int>(batches.size()); ++i) {
auto& offsets = batches[i]->RowOffsets();
std::copy(offsets.begin() + 1, offsets.end(), offset_vec.begin() + batch_offsets[i] + 1);
}
}
}
// Synchronise worker columns
info_.num_col_ = adapter->NumColumns();
info_.data_split_mode = data_split_mode;
ReindexFeatures(&ctx);
info_.SynchronizeNumberOfColumns();

info_.num_row_ = total_batch_size;
info_.num_nonzero_ = data_vec.size();
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);

fmat_ctx_ = ctx;
}
} // namespace xgboost::data
@@ -10,9 +10,7 @@
#include "xgboost/context.h"  // for Context
#include "xgboost/data.h"

namespace xgboost {
namespace data {

namespace xgboost::data {
// Does not currently support metainfo as no on-device data source contains this
// Current implementation assumes a single batch. More batches can
// be supported in future. Does not currently support inferring row/column size
@@ -21,14 +19,14 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthr
DataSplitMode data_split_mode) {
CHECK(data_split_mode != DataSplitMode::kCol)
<< "Column-wise data split is currently not supported on the GPU.";
auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ? dh::CurrentDevice()
: adapter->DeviceIdx();
CHECK_GE(device, 0);
dh::safe_cuda(cudaSetDevice(device));
auto device = (adapter->Device().IsCPU() || adapter->NumRows() == 0)
? DeviceOrd::CUDA(dh::CurrentDevice())
: adapter->Device();
CHECK(device.IsCUDA());
dh::safe_cuda(cudaSetDevice(device.ordinal));

Context ctx;
ctx.Init(Args{{"nthread", std::to_string(nthread)}, {"device", DeviceOrd::CUDA(device).Name()}});
ctx.Init(Args{{"nthread", std::to_string(nthread)}, {"device", device.Name()}});

CHECK(adapter->NumRows() != kAdapterUnknownSize);
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
@@ -53,5 +51,4 @@ template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
int nthread, DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(CupyAdapter* adapter, float missing,
int nthread, DataSplitMode data_split_mode);
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
@@ -54,11 +54,9 @@ void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data,
}

template <typename AdapterBatchT>
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
int device_idx, float missing) {
dh::safe_cuda(cudaSetDevice(device_idx));
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset, DeviceOrd device,
float missing) {
dh::safe_cuda(cudaSetDevice(device.ordinal));
IsValidFunctor is_valid(missing);
// Count elements per row
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
@@ -71,22 +69,19 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
});

dh::XGBCachingDeviceAllocator<char> alloc;

#if defined(XGBOOST_USE_HIP)
thrust::exclusive_scan(thrust::hip::par(alloc),
thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
#elif defined(XGBOOST_USE_CUDA)
thrust::exclusive_scan(thrust::cuda::par(alloc),
thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
#if defined(XGBOOST_USE_CUDA)
thrust::exclusive_scan(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
#elif defined(XGBOOST_USE_HIP)
thrust::exclusive_scan(thrust::hip::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
#endif
}
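CountRowOffsets is the standard two-pass CSR construction: write each row's valid-entry count into its slot, then exclusive-scan the array in place so slot i becomes the start offset of row i and the final slot becomes the total. A host-side illustration of the same transform, with a sequential scan standing in for the thrust call:

    #include <cstddef>
    #include <numeric>
    #include <vector>

    int main() {
      // Valid-entry counts for 4 rows, plus one trailing slot; n+1 entries total.
      std::vector<std::size_t> offset{2, 0, 3, 1, 0};
      // In-place exclusive scan -> {0, 2, 2, 5, 6}; row i spans [offset[i], offset[i+1]).
      std::exclusive_scan(offset.begin(), offset.end(), offset.begin(), std::size_t{0});
      return offset.back() == 6 ? 0 : 1;  // back() is the total element count
    }

The CUDA and HIP branches differ only in the execution policy (`thrust::cuda::par` vs `thrust::hip::par`); the scan itself is identical, which is what makes this part of the ROCm port largely mechanical.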
template <typename AdapterBatchT>
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing,
size_t CopyToSparsePage(AdapterBatchT const& batch, DeviceOrd device, float missing,
SparsePage* page) {
bool valid = NoInfInData(batch, IsValidFunctor{missing});
CHECK(valid) << error::InfInData();
@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2022 by XGBoost Contributors
/**
* Copyright 2015-2023, XGBoost Contributors
* \file simple_dmatrix.h
* \brief In-memory version of DMatrix.
* \author Tianqi Chen
@@ -15,8 +15,7 @@

#include "gradient_index.h"

namespace xgboost {
namespace data {
namespace xgboost::data {
// Used for single batch data.
class SimpleDMatrix : public DMatrix {
public:
@@ -65,9 +64,10 @@ class SimpleDMatrix : public DMatrix {
/**
* \brief Reindex the features based on a global view.
*
* In some cases (e.g. vertical federated learning), features are loaded locally with indices
* starting from 0. However, all the algorithms assume the features are globally indexed, so we
* reindex the features based on the offset needed to obtain the global view.
* In some cases (e.g. column-wise data split and vertical federated learning), features are
* loaded locally with indices starting from 0. However, all the algorithms assume the features
* are globally indexed, so we reindex the features based on the offset needed to obtain the
* global view.
*/
void ReindexFeatures(Context const* ctx);

@@ -75,6 +75,5 @@ class SimpleDMatrix : public DMatrix {
// Context used only for DMatrix initialization.
Context fmat_ctx_;
};
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
#endif  // XGBOOST_DATA_SIMPLE_DMATRIX_H_
@@ -45,7 +45,8 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(Context const* ctx,
ellpack_page_source_.reset();  // make sure resource is released before making new ones.
ellpack_page_source_ = std::make_shared<EllpackPageSource>(
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_, ctx->gpu_id);
param, std::move(cuts), this->IsDense(), row_stride, ft, sparse_page_source_,
ctx->Device());
} else {
CHECK(sparse_page_source_);
ellpack_page_source_->Reset();
@@ -19,11 +19,11 @@ std::size_t NFeaturesDevice(DMatrixProxy *proxy) {
} // namespace detail

void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
auto device = proxy->DeviceIdx();
if (device < 0) {
device = dh::CurrentDevice();
auto device = proxy->Device();
if (device.IsCPU()) {
device = DeviceOrd::CUDA(dh::CurrentDevice());
}
CHECK_GE(device, 0);
CHECK(device.IsCUDA());

cuda_impl::Dispatch(proxy,
[&](auto const &value) { CopyToSparsePage(value, device, missing, page); });
@@ -177,15 +177,15 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
}
// A heuristic for the number of pre-fetched batches. We can make it part of BatchParam
// to let the user adjust the number of pre-fetched batches when needed.
uint32_t constexpr kPreFetch = 3;

size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
std::int32_t n_prefetches = std::max(nthreads_, 3);
std::int32_t n_prefetch_batches =
std::min(static_cast<std::uint32_t>(n_prefetches), n_batches_);
CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_;
std::size_t fetch_it = count_;

exce_.Rethrow();

for (std::size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
for (std::int32_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_;  // ring
if (ring_->at(fetch_it).valid()) {
continue;
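The new heuristic scales prefetch depth with the worker's thread count (minimum 3) instead of using a fixed constant, clamps it to the number of batches, and walks the ring of pending pages from the current position. A standalone illustration with made-up sizes:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      std::int32_t nthreads = 8;
      std::uint32_t n_batches = 5;
      // Depth: at least 3, up to nthreads, never more than the batch count.
      std::int32_t n_prefetch =
          std::min(static_cast<std::uint32_t>(std::max(nthreads, 3)), n_batches);
      std::uint32_t fetch_it = 3;  // current batch (count_)
      for (std::int32_t i = 0; i < n_prefetch; ++i, ++fetch_it) {
        std::printf("prefetch slot %u\n", fetch_it % n_batches);  // ring order: 3 4 0 1 2
      }
      return 0;
    }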