Enhance inplace prediction. (#6653)
* Accept array interface for csr and array. * Accept an optional proxy dmatrix for metainfo. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused doc. * Add strict output.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019 by Contributors
|
||||
* Copyright 2019-2021 by Contributors
|
||||
* \file array_interface.h
|
||||
* \brief View of __array_interface__
|
||||
*/
|
||||
@@ -87,7 +87,7 @@ struct ArrayInterfaceErrors {
|
||||
}
|
||||
}
|
||||
|
||||
static std::string UnSupportedType(const char (&typestr)[3]) {
|
||||
static std::string UnSupportedType(StringView typestr) {
|
||||
return TypeStr(typestr[1]) + " is not supported.";
|
||||
}
|
||||
};
|
||||
@@ -210,6 +210,7 @@ class ArrayInterfaceHandler {
|
||||
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
|
||||
Validate(column);
|
||||
@@ -257,16 +258,24 @@ class ArrayInterface {
|
||||
}
|
||||
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
type[0] = typestr.at(0);
|
||||
type[1] = typestr.at(1);
|
||||
type[2] = typestr.at(2);
|
||||
this->CheckType();
|
||||
this->AssignType(StringView{typestr});
|
||||
}
|
||||
|
||||
public:
|
||||
enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||
|
||||
public:
|
||||
ArrayInterface() = default;
|
||||
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
|
||||
auto jinterface = Json::Load({str.c_str(), str.size()});
|
||||
explicit ArrayInterface(std::string const &str, bool allow_mask = true)
|
||||
: ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {}
|
||||
|
||||
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
this->Initialize(column, allow_mask);
|
||||
}
|
||||
|
||||
explicit ArrayInterface(StringView str, bool allow_mask = true) {
|
||||
auto jinterface = Json::Load(str);
|
||||
if (IsA<Object>(jinterface)) {
|
||||
this->Initialize(get<Object const>(jinterface), allow_mask);
|
||||
return;
|
||||
@@ -279,71 +288,114 @@ class ArrayInterface {
|
||||
}
|
||||
}
|
||||
|
||||
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
this->Initialize(column, allow_mask);
|
||||
}
|
||||
|
||||
void CheckType() const {
|
||||
if (type[1] == 'f' && type[2] == '4') {
|
||||
return;
|
||||
} else if (type[1] == 'f' && type[2] == '8') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '1') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '2') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '4') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '8') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '1') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '2') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '4') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '8') {
|
||||
return;
|
||||
void AssignType(StringView typestr) {
|
||||
if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||
type = kF4;
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '8') {
|
||||
type = kF8;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '1') {
|
||||
type = kI1;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '2') {
|
||||
type = kI2;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '4') {
|
||||
type = kI4;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '8') {
|
||||
type = kI8;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '1') {
|
||||
type = kU1;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '2') {
|
||||
type = kU2;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '4') {
|
||||
type = kU4;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '8') {
|
||||
type = kU8;
|
||||
} else {
|
||||
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type);
|
||||
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE float GetElement(size_t idx) const {
|
||||
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
|
||||
void* p_values;
|
||||
switch (type) {
|
||||
case kF4:
|
||||
p_values = reinterpret_cast<float *>(data) + offset;
|
||||
break;
|
||||
case kF8:
|
||||
p_values = reinterpret_cast<double *>(data) + offset;
|
||||
break;
|
||||
case kI1:
|
||||
p_values = reinterpret_cast<int8_t *>(data) + offset;
|
||||
break;
|
||||
case kI2:
|
||||
p_values = reinterpret_cast<int16_t *>(data) + offset;
|
||||
break;
|
||||
case kI4:
|
||||
p_values = reinterpret_cast<int32_t *>(data) + offset;
|
||||
break;
|
||||
case kI8:
|
||||
p_values = reinterpret_cast<int64_t *>(data) + offset;
|
||||
break;
|
||||
case kU1:
|
||||
p_values = reinterpret_cast<uint8_t *>(data) + offset;
|
||||
break;
|
||||
case kU2:
|
||||
p_values = reinterpret_cast<uint16_t *>(data) + offset;
|
||||
break;
|
||||
case kU4:
|
||||
p_values = reinterpret_cast<uint32_t *>(data) + offset;
|
||||
break;
|
||||
case kU8:
|
||||
p_values = reinterpret_cast<uint64_t *>(data) + offset;
|
||||
break;
|
||||
}
|
||||
ArrayInterface ret = *this;
|
||||
ret.data = p_values;
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
|
||||
size_t offset = idx * num_cols;
|
||||
auto ret = this->SliceOffset(offset);
|
||||
ret.num_rows = 1;
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t idx) const {
|
||||
SPAN_CHECK(idx < num_cols * num_rows);
|
||||
if (type[1] == 'f' && type[2] == '4') {
|
||||
switch (type) {
|
||||
case kF4:
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
} else if (type[1] == 'f' && type[2] == '8') {
|
||||
case kF8:
|
||||
return reinterpret_cast<double*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '1') {
|
||||
case kI1:
|
||||
return reinterpret_cast<int8_t*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '2') {
|
||||
case kI2:
|
||||
return reinterpret_cast<int16_t*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '4') {
|
||||
case kI4:
|
||||
return reinterpret_cast<int32_t*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '8') {
|
||||
case kI8:
|
||||
return reinterpret_cast<int64_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '1') {
|
||||
case kU1:
|
||||
return reinterpret_cast<uint8_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '2') {
|
||||
case kU2:
|
||||
return reinterpret_cast<uint16_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '4') {
|
||||
case kU4:
|
||||
return reinterpret_cast<uint32_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '8') {
|
||||
case kU8:
|
||||
return reinterpret_cast<uint64_t*>(data)[idx];
|
||||
} else {
|
||||
SPAN_CHECK(false);
|
||||
return 0;
|
||||
}
|
||||
SPAN_CHECK(false);
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
void* data;
|
||||
char type[3];
|
||||
|
||||
Type type;
|
||||
};
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user