From ee4f51a631e12523f7a6771d6446c99231c88b19 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 1 Jun 2021 08:34:48 +0800 Subject: [PATCH] Support for all primitive types from array. (#7003) * Change C API name. * Test for all primitive types from array. * Add native support for CPU 128 float. * Convert boolean and float16 in Python. * Fix dask version for now. --- include/xgboost/c_api.h | 2 +- python-package/xgboost/core.py | 18 +++++-- python-package/xgboost/data.py | 17 +++--- src/c_api/c_api.cc | 2 +- src/data/array_interface.h | 29 ++++++++--- tests/ci_build/conda_env/macos_cpu_test.yml | 4 +- tests/python-gpu/test_gpu_prediction.py | 58 +++++++++++++++++++++ tests/python/test_predict.py | 48 +++++++++++++++++ 8 files changed, 154 insertions(+), 24 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 90e4b185f..be3444252 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -142,7 +142,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, * \param out created dmatrix * \return 0 when success, -1 when failure happens */ -XGB_DLL int XGDMatrixCreateFromArray(char const *data, +XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *json_config, DMatrixHandle *out); diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 18090a3ac..0247e0ad8 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -239,7 +239,18 @@ def _array_interface(data: np.ndarray) -> bytes: interface = data.__array_interface__ if "mask" in interface: interface["mask"] = interface["mask"].__array_interface__ - interface_str = bytes(json.dumps(interface, indent=2), "utf-8") + interface_str = bytes(json.dumps(interface), "utf-8") + return interface_str + + +def _cuda_array_interface(data) -> bytes: + assert ( + data.dtype.hasobject is False + ), "Input data contains `object` dtype. Expecting numeric data." 
+ interface = data.__cuda_array_interface__ + if "mask" in interface: + interface["mask"] = interface["mask"].__cuda_array_interface__ + interface_str = bytes(json.dumps(interface), "utf-8") return interface_str @@ -1948,10 +1959,7 @@ class Booster(object): from .data import _transform_cupy_array data = _transform_cupy_array(data) - interface = data.__cuda_array_interface__ - if "mask" in interface: - interface["mask"] = interface["mask"].__cuda_array_interface__ - interface_str = bytes(json.dumps(interface, indent=2), "utf-8") + interface_str = _cuda_array_interface(data) _check_call( _LIB.XGBoosterPredictFromCudaArray( self.handle, diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 1c6305f38..2b0d0bb1f 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -9,7 +9,8 @@ from typing import Any import numpy as np -from .core import c_array, _LIB, _check_call, c_str, _array_interface +from .core import c_array, _LIB, _check_call, c_str +from .core import _array_interface, _cuda_array_interface from .core import DataIter, _ProxyDMatrix, DMatrix from .compat import lazy_isinstance @@ -105,7 +106,7 @@ def _is_numpy_array(data): def _ensure_np_dtype(data, dtype): - if data.dtype.hasobject: + if data.dtype.hasobject or data.dtype in [np.float16, np.bool_]: data = data.astype(np.float32, copy=False) dtype = np.float32 return data, dtype @@ -141,7 +142,7 @@ def _from_numpy_array(data, missing, nthread, feature_names, feature_types): } config = bytes(json.dumps(args), "utf-8") _check_call( - _LIB.XGDMatrixCreateFromArray( + _LIB.XGDMatrixCreateFromDense( _array_interface(data), config, ctypes.byref(handle), @@ -416,21 +417,19 @@ def _is_cupy_array(data): def _transform_cupy_array(data): + import cupy # pylint: disable=import-error if not hasattr(data, '__cuda_array_interface__') and hasattr( data, '__array__'): - import cupy # pylint: disable=import-error data = cupy.array(data, copy=False) + if data.dtype.hasobject or data.dtype in [cupy.float16, cupy.bool_]: + data = data.astype(cupy.float32, copy=False) return data def _from_cupy_array(data, missing, nthread, feature_names, feature_types): """Initialize DMatrix from cupy ndarray.""" data = _transform_cupy_array(data) - interface = data.__cuda_array_interface__ - if 'mask' in interface: - interface['mask'] = interface['mask'].__cuda_array_interface__ - interface_str = bytes(json.dumps(interface, indent=2), 'utf-8') - + interface_str = _cuda_array_interface(data) handle = ctypes.c_void_p() _check_call( _LIB.XGDMatrixCreateFromArrayInterface( diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 73c405ba7..32a5ae167 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -261,7 +261,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, API_END(); } -XGB_DLL int XGDMatrixCreateFromArray(char const *data, +XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *c_json_config, DMatrixHandle *out) { API_BEGIN(); diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 9d87f316d..35d82e8ed 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -42,7 +42,8 @@ struct ArrayInterfaceErrors { return str.c_str(); } static char const* Version() { - return "Only version <= 3 of `__cuda_array_interface__' are supported."; + return "Only version <= 3 of " + "`__cuda_array_interface__/__array_interface__' are supported."; } static char const* OfType(std::string const& type) { static std::string str; @@ -81,7 +82,7 @@ struct 
ArrayInterfaceErrors { return "Other"; default: LOG(FATAL) << "Invalid type code: " << c << " in `typestr' of input array." - << "\nPlease verify the `__cuda_array_interface__' " + << "\nPlease verify the `__cuda_array_interface__/__array_interface__' " << "of your input data complies to: " << "https://docs.scipy.org/doc/numpy/reference/arrays.interface.html" << "\nOr open an issue."; @@ -90,7 +91,7 @@ struct ArrayInterfaceErrors { } static std::string UnSupportedType(StringView typestr) { - return TypeStr(typestr[1]) + " is not supported."; + return TypeStr(typestr[1]) + "-" + typestr[2] + " is not supported."; } }; @@ -135,8 +136,9 @@ class ArrayInterfaceHandler { if (array.find("typestr") == array.cend()) { LOG(FATAL) << "Missing `typestr' field for array interface"; } + auto typestr = get(array.at("typestr")); - CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat(); + CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat(); CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian(); if (array.find("shape") == array.cend()) { @@ -295,7 +297,7 @@ class ArrayInterface { } public: - enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; + enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; public: ArrayInterface() = default; @@ -331,7 +333,12 @@ class ArrayInterface { } void AssignType(StringView typestr) { - if (typestr[1] == 'f' && typestr[2] == '4') { + if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && + typestr[3] == '6') { + type = kF16; + CHECK(sizeof(long double) == 16) + << "128-bit floating point is not supported on current platform."; + } else if (typestr[1] == 'f' && typestr[2] == '4') { type = kF4; } else if (typestr[1] == 'f' && typestr[2] == '8') { type = kF8; @@ -364,6 +371,16 @@ class ArrayInterface { return func(reinterpret_cast(data)); case kF8: return func(reinterpret_cast(data)); +#ifdef __CUDA_ARCH__ + case kF16: { + // CUDA device code doesn't support long double. 
+ SPAN_CHECK(false); + return func(reinterpret_cast(data)); + } +#else + case kF16: + return func(reinterpret_cast(data)); +#endif case kI1: return func(reinterpret_cast(data)); case kI2: diff --git a/tests/ci_build/conda_env/macos_cpu_test.yml b/tests/ci_build/conda_env/macos_cpu_test.yml index 6a9092b2b..318005314 100644 --- a/tests/ci_build/conda_env/macos_cpu_test.yml +++ b/tests/ci_build/conda_env/macos_cpu_test.yml @@ -13,8 +13,8 @@ dependencies: - scikit-learn - pandas - matplotlib -- dask -- distributed +- dask=2021.05.0 +- distributed=2021.05.0 - graphviz - python-graphviz - hypothesis diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py index 71c27352e..ce73034a4 100644 --- a/tests/python-gpu/test_gpu_prediction.py +++ b/tests/python-gpu/test_gpu_prediction.py @@ -204,6 +204,7 @@ class TestGPUPredict: cpu_predt = reg.predict(X) np.testing.assert_allclose(gpu_predt, cpu_predt, atol=1e-6) + @pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cudf()) def test_inplace_predict_cudf(self): import cupy as cp @@ -332,6 +333,7 @@ class TestGPUPredict: rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False) np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5) + @pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.parametrize("n_classes", [2, 3]) def test_predict_dart(self, n_classes): from sklearn.datasets import make_classification @@ -378,3 +380,59 @@ class TestGPUPredict: copied = cp.array(copied) cp.testing.assert_allclose(inplace, copied, atol=1e-6) + + @pytest.mark.skipif(**tm.no_cupy()) + def test_dtypes(self): + import cupy as cp + rows = 1000 + cols = 10 + rng = cp.random.RandomState(1994) + orig = rng.randint(low=0, high=127, size=rows * cols).reshape( + rows, cols + ) + y = rng.randint(low=0, high=127, size=rows) + dtrain = xgb.DMatrix(orig, label=y) + booster = xgb.train({"tree_method": "gpu_hist"}, dtrain) + + predt_orig = booster.inplace_predict(orig) + # all primitive types in numpy + for dtype in [ + cp.signedinteger, + cp.byte, + cp.short, + cp.intc, + cp.int_, + cp.longlong, + cp.unsignedinteger, + cp.ubyte, + cp.ushort, + cp.uintc, + cp.uint, + cp.ulonglong, + cp.floating, + cp.half, + cp.single, + cp.double, + ]: + X = cp.array(orig, dtype=dtype) + predt = booster.inplace_predict(X) + cp.testing.assert_allclose(predt, predt_orig) + + # boolean + orig = cp.random.binomial(1, 0.5, size=rows * cols).reshape( + rows, cols + ) + predt_orig = booster.inplace_predict(orig) + for dtype in [cp.bool8, cp.bool_]: + X = cp.array(orig, dtype=dtype) + predt = booster.inplace_predict(X) + cp.testing.assert_allclose(predt, predt_orig) + + # unsupported types + for dtype in [ + cp.complex64, + cp.complex128, + ]: + X = cp.array(orig, dtype=dtype) + with pytest.raises(ValueError): + booster.inplace_predict(X) diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index 709a633bb..800a4838a 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -237,3 +237,51 @@ class TestInplacePredict: dtrain = xgb.DMatrix(self.X, self.y, base_margin=base_margin) from_dmatrix = booster.predict(dtrain) np.testing.assert_allclose(from_dmatrix, from_inplace) + + def test_dtypes(self): + orig = self.rng.randint(low=0, high=127, size=self.rows * self.cols).reshape( + self.rows, self.cols + ) + predt_orig = self.booster.inplace_predict(orig) + # all primitive types in numpy + for dtype in [ + np.signedinteger, + np.byte, + np.short, + np.intc, + np.int_, + np.longlong, + 
np.unsignedinteger, + np.ubyte, + np.ushort, + np.uintc, + np.uint, + np.ulonglong, + np.floating, + np.half, + np.single, + np.double, + ]: + X = np.array(orig, dtype=dtype) + predt = self.booster.inplace_predict(X) + np.testing.assert_allclose(predt, predt_orig) + + # boolean + orig = self.rng.binomial(1, 0.5, size=self.rows * self.cols).reshape( + self.rows, self.cols + ) + predt_orig = self.booster.inplace_predict(orig) + for dtype in [np.bool8, np.bool_]: + X = np.array(orig, dtype=dtype) + predt = self.booster.inplace_predict(X) + np.testing.assert_allclose(predt, predt_orig) + + # unsupported types + for dtype in [ + np.string_, + np.complex64, + np.complex128, + ]: + X = np.array(orig, dtype=dtype) + with pytest.raises(ValueError): + self.booster.inplace_predict(X)
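
Usage note (a minimal sketch, not part of the patch): with this change, numpy inputs of any primitive numeric dtype are accepted by inplace prediction; float16 and bool inputs are converted to float32 by the Python wrapper (see _ensure_np_dtype and _transform_cupy_array above) before reaching native code, while 128-bit floats ('f16' typestr) are dispatched natively only where sizeof(long double) == 16. The dataset shape and training parameters below are illustrative only.

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(1994)
    X = rng.randint(low=0, high=127, size=(1000, 10))
    y = rng.randint(low=0, high=127, size=1000)
    booster = xgb.train({"tree_method": "hist"}, xgb.DMatrix(X, label=y))

    baseline = booster.inplace_predict(X)
    # Values 0..126 are exactly representable in every dtype used here,
    # so predictions from the converted inputs agree with the baseline.
    for dtype in [np.int8, np.uint16, np.int64, np.float16, np.float64]:
        predt = booster.inplace_predict(X.astype(dtype))
        np.testing.assert_allclose(predt, baseline)

Converting float16 and bool on the Python side keeps the native dispatch limited to types with direct C++ equivalents; unsupported dtypes such as complex numbers or strings still raise a ValueError, as exercised by the tests above.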