Use __array_interface__ for creating DMatrix from CSR. (#6675)
* Use __array_interface__ for creating DMatrix from CSR. * Add configuration.
This commit is contained in:
parent
1e949110da
commit
dbb5208a0a
@ -109,6 +109,27 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
|
||||
size_t nelem,
|
||||
size_t num_col,
|
||||
DMatrixHandle* out);
|
||||
|
||||
/*!
|
||||
* \brief Create a matrix from CSR matrix.
|
||||
* \param indptr JSON encoded __array_interface__ to row pointers in CSR.
|
||||
* \param indices JSON encoded __array_interface__ to column indices in CSR.
|
||||
* \param data JSON encoded __array_interface__ to values in CSR.
|
||||
* \param num_col Number of columns.
|
||||
* \param json_config JSON encoded configuration. Required values are:
|
||||
*
|
||||
* - missing
|
||||
* - nthread
|
||||
*
|
||||
* \param out created dmatrix
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
|
||||
char const *indices, char const *data,
|
||||
bst_ulong ncol,
|
||||
char const* json_config,
|
||||
DMatrixHandle* out);
|
||||
|
||||
/*!
|
||||
* \brief create a matrix content from CSC format
|
||||
* \param col_ptr pointer to col headers
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .core import c_array, _LIB, _check_call, c_str
|
||||
from .core import c_array, _LIB, _check_call, c_str, _array_interface
|
||||
from .core import DataIter, _ProxyDMatrix, DMatrix
|
||||
from .compat import lazy_isinstance
|
||||
|
||||
@ -40,21 +40,28 @@ def _is_scipy_csr(data):
|
||||
return isinstance(data, scipy.sparse.csr_matrix)
|
||||
|
||||
|
||||
def _from_scipy_csr(data, missing, feature_names, feature_types):
|
||||
'''Initialize data from a CSR matrix.'''
|
||||
def _from_scipy_csr(data, missing, nthread, feature_names, feature_types):
|
||||
"""Initialize data from a CSR matrix."""
|
||||
if len(data.indices) != len(data.data):
|
||||
raise ValueError('length mismatch: {} vs {}'.format(
|
||||
len(data.indices), len(data.data)))
|
||||
_warn_unused_missing(data, missing)
|
||||
raise ValueError(
|
||||
"length mismatch: {} vs {}".format(len(data.indices), len(data.data))
|
||||
)
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGDMatrixCreateFromCSREx(
|
||||
c_array(ctypes.c_size_t, data.indptr),
|
||||
c_array(ctypes.c_uint, data.indices),
|
||||
c_array(ctypes.c_float, data.data),
|
||||
ctypes.c_size_t(len(data.indptr)),
|
||||
ctypes.c_size_t(len(data.data)),
|
||||
ctypes.c_size_t(data.shape[1]),
|
||||
ctypes.byref(handle)))
|
||||
args = {
|
||||
"missing": float(missing),
|
||||
"nthread": int(nthread),
|
||||
}
|
||||
config = bytes(json.dumps(args), "utf-8")
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromCSR(
|
||||
_array_interface(data.indptr),
|
||||
_array_interface(data.indices),
|
||||
_array_interface(data.data),
|
||||
ctypes.c_size_t(data.shape[1]),
|
||||
config,
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
return handle, feature_names, feature_types
|
||||
|
||||
|
||||
@ -541,11 +548,11 @@ def dispatch_data_backend(data, missing, threads,
|
||||
enable_categorical=False):
|
||||
'''Dispatch data for DMatrix.'''
|
||||
if _is_scipy_csr(data):
|
||||
return _from_scipy_csr(data, missing, feature_names, feature_types)
|
||||
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
|
||||
if _is_scipy_csc(data):
|
||||
return _from_scipy_csc(data, missing, feature_names, feature_types)
|
||||
if _is_scipy_coo(data):
|
||||
return _from_scipy_csr(data.tocsr(), missing, feature_names, feature_types)
|
||||
return _from_scipy_csr(data.tocsr(), missing, threads, feature_names, feature_types)
|
||||
if _is_numpy_array(data):
|
||||
return _from_numpy_array(data, missing, threads, feature_names,
|
||||
feature_types)
|
||||
@ -592,7 +599,7 @@ def dispatch_data_backend(data, missing, threads,
|
||||
|
||||
converted = _convert_unknown_data(data)
|
||||
if converted:
|
||||
return _from_scipy_csr(data, missing, feature_names, feature_types)
|
||||
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
|
||||
|
||||
raise TypeError('Not supported type for data.' + str(type(data)))
|
||||
|
||||
|
||||
@ -246,6 +246,21 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
|
||||
char const *indices, char const *data,
|
||||
xgboost::bst_ulong ncol,
|
||||
char const* c_json_config,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
data::CSRArrayAdapter adapter(StringView{indptr}, StringView{indices},
|
||||
StringView{data}, ncol);
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = get<Number const>(config["missing"]);
|
||||
auto nthread = get<Integer const>(config["nthread"]);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
|
||||
const unsigned* indices,
|
||||
const bst_float* data,
|
||||
|
||||
@ -306,6 +306,12 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||
values_{std::move(values)} {}
|
||||
|
||||
size_t Size() const {
|
||||
size_t size = indptr_.num_rows * indptr_.num_cols;
|
||||
size = size == 0 ? 0 : size - 1;
|
||||
return size;
|
||||
}
|
||||
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
|
||||
|
||||
@ -812,6 +812,9 @@ template DMatrix* DMatrix::Create<data::DataTableAdapter>(
|
||||
template DMatrix* DMatrix::Create<data::FileAdapter>(
|
||||
data::FileAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix* DMatrix::Create<data::CSRArrayAdapter>(
|
||||
data::CSRArrayAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix *
|
||||
DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
||||
XGBoostBatchCSR> *adapter,
|
||||
|
||||
@ -209,6 +209,8 @@ template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
|
||||
|
||||
@ -37,7 +37,7 @@ class TestDMatrix:
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
csr = csr_matrix(x)
|
||||
xgb.DMatrix(csr, y, missing=4)
|
||||
xgb.DMatrix(csr.tocsc(), y, missing=4)
|
||||
|
||||
def test_dmatrix_numpy_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
@ -284,6 +284,31 @@ class TestDMatrix:
|
||||
bst = xgb.train(param, dtrain, 5, watchlist)
|
||||
bst.predict(dtrain)
|
||||
|
||||
i32 = csr_matrix((x.data.astype(np.int32), x.indices, x.indptr), shape=x.shape)
|
||||
f32 = csr_matrix(
|
||||
(i32.data.astype(np.float32), x.indices, x.indptr), shape=x.shape
|
||||
)
|
||||
di32 = xgb.DMatrix(i32)
|
||||
df32 = xgb.DMatrix(f32)
|
||||
dense = xgb.DMatrix(f32.toarray(), missing=0)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "f32.dmatrix")
|
||||
df32.save_binary(path)
|
||||
with open(path, "rb") as fd:
|
||||
df32_buffer = np.array(fd.read())
|
||||
path = os.path.join(tmpdir, "f32.dmatrix")
|
||||
di32.save_binary(path)
|
||||
with open(path, "rb") as fd:
|
||||
di32_buffer = np.array(fd.read())
|
||||
|
||||
path = os.path.join(tmpdir, "dense.dmatrix")
|
||||
dense.save_binary(path)
|
||||
with open(path, "rb") as fd:
|
||||
dense_buffer = np.array(fd.read())
|
||||
|
||||
np.testing.assert_equal(df32_buffer, di32_buffer)
|
||||
np.testing.assert_equal(df32_buffer, dense_buffer)
|
||||
|
||||
def test_sparse_dmatrix_csc(self):
|
||||
nrow = 1000
|
||||
ncol = 100
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user