Enhance inplace prediction. (#6653)

* Accept array interface for csr and array.
* Accept an optional proxy dmatrix for metainfo.

This constructs an explicit `_ProxyDMatrix` type in Python.

* Remove unused doc.
* Add strict output.
This commit is contained in:
Jiaming Yuan
2021-02-02 11:41:46 +08:00
committed by GitHub
parent 87ab1ad607
commit 411592a347
22 changed files with 955 additions and 530 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2019~2020 by Contributors
* Copyright (c) 2019~2021 by Contributors
* \file adapter.h
*/
#ifndef XGBOOST_DATA_ADAPTER_H_
@@ -228,6 +228,128 @@ class DenseAdapter : public detail::SingleBatchDataIter<DenseAdapterBatch> {
size_t num_columns_;
};
/**
 * Batch over a dense array described by an `ArrayInterface`.  Rows are handed out one
 * at a time through `GetLine`, each as a lightweight `Line` view.
 */
class ArrayAdapterBatch : public detail::NoMetaInfo {
  ArrayInterface array_interface_;

  /** A single row of the array; elements are reported as COO tuples. */
  class Line {
    ArrayInterface array_interface_;  // row view produced by `SliceRow`
    size_t ridx_;                     // row index placed in every returned tuple

   public:
    Line(ArrayInterface array_interface, size_t ridx)
        : array_interface_{std::move(array_interface)}, ridx_{ridx} {}

    /** Number of elements in this row, i.e. the column count of the row view. */
    size_t Size() const { return array_interface_.num_cols; }

    /** Tuple of (row index, column index `idx`, value stored at `idx` in the row view). */
    COOTuple GetElement(size_t idx) const {
      return {ridx_, idx, array_interface_.GetElement(idx)};
    }
  };

 public:
  ArrayAdapterBatch() = default;

  explicit ArrayAdapterBatch(ArrayInterface array_interface)
      : array_interface_{std::move(array_interface)} {}

  /** Build the `Line` for row `idx` by slicing that row out of the stored array. */
  Line const GetLine(size_t idx) const {
    return Line{array_interface_.SliceRow(idx), idx};
  }
};
/**
 * Adapter for a dense host array; in Python that is `numpy.ndarray`.  It plays the same
 * role as `DenseAdapter` but is constructed from an `__array_interface__` JSON string
 * rather than a raw pointer, which lets it handle various data types without making a
 * copy.
 */
class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
 public:
  /**
   * \param array_interface JSON text of the array interface; it is loaded and read as a
   *                        JSON object.
   */
  explicit ArrayAdapter(StringView array_interface) {
    auto parsed = Json::Load(array_interface);
    array_interface_ = ArrayInterface(get<Object const>(parsed));
    // The batch holds its own copy of the parsed interface.
    batch_ = ArrayAdapterBatch{array_interface_};
  }

  ArrayAdapterBatch const& Value() const override { return batch_; }

  size_t NumRows() const { return array_interface_.num_rows; }
  size_t NumColumns() const { return array_interface_.num_cols; }

 private:
  ArrayAdapterBatch batch_;
  ArrayInterface array_interface_;
};
class CSRArrayAdapterBatch : public detail::NoMetaInfo {
ArrayInterface indptr_;
ArrayInterface indices_;
ArrayInterface values_;
class Line {
ArrayInterface indices_;
ArrayInterface values_;
size_t ridx_;
public:
Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}
COOTuple GetElement(size_t idx) const {
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
}
size_t Size() const {
return values_.num_rows * values_.num_cols;
}
};
public:
CSRArrayAdapterBatch() = default;
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
ArrayInterface values)
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
values_{std::move(values)} {}
Line const GetLine(size_t idx) const {
auto begin_offset = indptr_.GetElement<size_t>(idx);
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
auto indices = indices_.SliceOffset(begin_offset);
auto values = values_.SliceOffset(begin_offset);
values.num_cols = end_offset - begin_offset;
values.num_rows = 1;
indices.num_cols = values.num_cols;
indices.num_rows = values.num_rows;
return Line{indices, values, idx};
}
};
/**
 * Adapter for a CSR matrix on host; in Python that is `scipy.sparse.csr_matrix`.  It
 * plays the same role as `CSRAdapter` but takes `__array_interface__` strings rather
 * than raw pointers, which lets it handle various data types without making a copy.
 */
class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch> {
 public:
  /**
   * \param indptr   Array interface string for the row offsets.
   * \param indices  Array interface string for the column indices.
   * \param values   Array interface string for the stored values.
   * \param num_cols Number of columns, reported by `NumColumns`.
   */
  CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
                  size_t num_cols)
      : indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
    // The batch is assembled from copies of the three parsed interfaces.
    batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
  }

  CSRArrayAdapterBatch const& Value() const override {
    return batch_;
  }

  /** Row count: one less than the number of offsets in `indptr`, or 0 if it is empty. */
  size_t NumRows() const {
    size_t const n_offsets = indptr_.num_cols * indptr_.num_rows;
    if (n_offsets == 0) {
      return 0;
    }
    return n_offsets - 1;
  }

  size_t NumColumns() const { return num_cols_; }

 private:
  CSRArrayAdapterBatch batch_;
  ArrayInterface indptr_;
  ArrayInterface indices_;
  ArrayInterface values_;
  size_t num_cols_;
};
class CSCAdapterBatch : public detail::NoMetaInfo {
public:
CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx,

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2019 by Contributors
* Copyright 2019-2021 by Contributors
* \file array_interface.h
* \brief View of __array_interface__
*/
@@ -87,7 +87,7 @@ struct ArrayInterfaceErrors {
}
}
static std::string UnSupportedType(const char (&typestr)[3]) {
static std::string UnSupportedType(StringView typestr) {
return TypeStr(typestr[1]) + " is not supported.";
}
};
@@ -210,6 +210,7 @@ class ArrayInterfaceHandler {
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
}
}
template <typename T>
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
Validate(column);
@@ -257,16 +258,24 @@ class ArrayInterface {
}
auto typestr = get<String const>(column.at("typestr"));
type[0] = typestr.at(0);
type[1] = typestr.at(1);
type[2] = typestr.at(2);
this->CheckType();
this->AssignType(StringView{typestr});
}
public:
enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
public:
ArrayInterface() = default;
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
auto jinterface = Json::Load({str.c_str(), str.size()});
explicit ArrayInterface(std::string const &str, bool allow_mask = true)
: ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {}
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
this->Initialize(column, allow_mask);
}
explicit ArrayInterface(StringView str, bool allow_mask = true) {
auto jinterface = Json::Load(str);
if (IsA<Object>(jinterface)) {
this->Initialize(get<Object const>(jinterface), allow_mask);
return;
@@ -279,71 +288,114 @@ class ArrayInterface {
}
}
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
this->Initialize(column, allow_mask);
}
void CheckType() const {
if (type[1] == 'f' && type[2] == '4') {
return;
} else if (type[1] == 'f' && type[2] == '8') {
return;
} else if (type[1] == 'i' && type[2] == '1') {
return;
} else if (type[1] == 'i' && type[2] == '2') {
return;
} else if (type[1] == 'i' && type[2] == '4') {
return;
} else if (type[1] == 'i' && type[2] == '8') {
return;
} else if (type[1] == 'u' && type[2] == '1') {
return;
} else if (type[1] == 'u' && type[2] == '2') {
return;
} else if (type[1] == 'u' && type[2] == '4') {
return;
} else if (type[1] == 'u' && type[2] == '8') {
return;
void AssignType(StringView typestr) {
if (typestr[1] == 'f' && typestr[2] == '4') {
type = kF4;
} else if (typestr[1] == 'f' && typestr[2] == '8') {
type = kF8;
} else if (typestr[1] == 'i' && typestr[2] == '1') {
type = kI1;
} else if (typestr[1] == 'i' && typestr[2] == '2') {
type = kI2;
} else if (typestr[1] == 'i' && typestr[2] == '4') {
type = kI4;
} else if (typestr[1] == 'i' && typestr[2] == '8') {
type = kI8;
} else if (typestr[1] == 'u' && typestr[2] == '1') {
type = kU1;
} else if (typestr[1] == 'u' && typestr[2] == '2') {
type = kU2;
} else if (typestr[1] == 'u' && typestr[2] == '4') {
type = kU4;
} else if (typestr[1] == 'u' && typestr[2] == '8') {
type = kU8;
} else {
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type);
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr);
return;
}
}
XGBOOST_DEVICE float GetElement(size_t idx) const {
/*!
 * \brief Create a view of this array that begins `offset` elements after `data`.
 *
 * The pointer is advanced in units of the element type selected by `type`.  Every
 * other member of the returned object is copied from `*this`, so callers that need
 * a different shape adjust `num_rows` / `num_cols` on the result themselves.
 *
 * \param offset Number of elements (not bytes) to skip.
 * \return Copy of this interface whose `data` points at the element at `offset`.
 */
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
  // Start from a defined value: the switch has no default branch, so if `type` does not
  // hold one of the enumerators (it has no initializer in a default-constructed
  // interface) the result carries a null pointer instead of an indeterminate one.
  void* p_values = nullptr;
  switch (type) {
    case kF4:
      p_values = reinterpret_cast<float *>(data) + offset;
      break;
    case kF8:
      p_values = reinterpret_cast<double *>(data) + offset;
      break;
    case kI1:
      p_values = reinterpret_cast<int8_t *>(data) + offset;
      break;
    case kI2:
      p_values = reinterpret_cast<int16_t *>(data) + offset;
      break;
    case kI4:
      p_values = reinterpret_cast<int32_t *>(data) + offset;
      break;
    case kI8:
      p_values = reinterpret_cast<int64_t *>(data) + offset;
      break;
    case kU1:
      p_values = reinterpret_cast<uint8_t *>(data) + offset;
      break;
    case kU2:
      p_values = reinterpret_cast<uint16_t *>(data) + offset;
      break;
    case kU4:
      p_values = reinterpret_cast<uint32_t *>(data) + offset;
      break;
    case kU8:
      p_values = reinterpret_cast<uint64_t *>(data) + offset;
      break;
  }
  ArrayInterface ret = *this;
  ret.data = p_values;
  return ret;
}
/*!
 * \brief Create a view that starts at row `idx` and reports a single row.
 *
 * The start is `idx * num_cols` elements into the array; `num_cols` is kept as is.
 *
 * \param idx Index of the row to view.
 */
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
  ArrayInterface row = this->SliceOffset(idx * num_cols);
  row.num_rows = 1;
  return row;
}
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t idx) const {
SPAN_CHECK(idx < num_cols * num_rows);
if (type[1] == 'f' && type[2] == '4') {
switch (type) {
case kF4:
return reinterpret_cast<float*>(data)[idx];
} else if (type[1] == 'f' && type[2] == '8') {
case kF8:
return reinterpret_cast<double*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '1') {
case kI1:
return reinterpret_cast<int8_t*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '2') {
case kI2:
return reinterpret_cast<int16_t*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '4') {
case kI4:
return reinterpret_cast<int32_t*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '8') {
case kI8:
return reinterpret_cast<int64_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '1') {
case kU1:
return reinterpret_cast<uint8_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '2') {
case kU2:
return reinterpret_cast<uint16_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '4') {
case kU4:
return reinterpret_cast<uint32_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '8') {
case kU8:
return reinterpret_cast<uint64_t*>(data)[idx];
} else {
SPAN_CHECK(false);
return 0;
}
SPAN_CHECK(false);
return reinterpret_cast<float*>(data)[idx];
}
RBitField8 valid;
bst_row_t num_rows;
bst_feature_t num_cols;
void* data;
char type[3];
Type type;
};
} // namespace xgboost

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2019 by XGBoost Contributors
* Copyright 2019-2021 by XGBoost Contributors
*
* \file data.cu
* \brief Handles setting metainfo from array interface.
@@ -45,15 +45,15 @@ auto SetDeviceToPtr(void *ptr) {
} // anonymous namespace
void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
CHECK(column.type[1] == 'i' || column.type[1] == 'u')
<< "Expected integer metainfo";
CHECK(column.type != ArrayInterface::kF4 && column.type != ArrayInterface::kF8)
<< "Expected integer for group info.";
auto ptr_device = SetDeviceToPtr(column.data);
dh::TemporaryArray<bst_group_t> temp(column.num_rows);
auto d_tmp = temp.data();
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
d_tmp[idx] = column.GetElement(idx);
d_tmp[idx] = column.GetElement<size_t>(idx);
});
auto length = column.num_rows;
out->resize(length + 1);
@@ -103,15 +103,15 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
auto it = dh::MakeTransformIterator<uint32_t>(
thrust::make_counting_iterator(0ul),
[array_interface] __device__(size_t i) {
return static_cast<uint32_t>(array_interface.GetElement(i));
return array_interface.GetElement<uint32_t>(i);
});
dh::caching_device_vector<bool> flag(1);
auto d_flag = dh::ToSpan(flag);
auto d = SetDeviceToPtr(array_interface.data);
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
if (static_cast<uint32_t>(array_interface.GetElement(i)) >
static_cast<uint32_t>(array_interface.GetElement(i + 1))) {
if (array_interface.GetElement<uint32_t>(i) >
array_interface.GetElement<uint32_t>(i + 1)) {
d_flag[0] = false;
}
});