diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index a382de4eb..ffcaa9e77 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -231,10 +231,11 @@ def c_array(ctype, values):
     return (ctype * len(values))(*values)
 
 
-PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
-                       'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
-                       'float16': 'float', 'float32': 'float', 'float64': 'float',
-                       'bool': 'i'}
+PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64':
+                       'int', 'uint8': 'int', 'uint16': 'int', 'uint32': 'int',
+                       'uint64': 'int', 'float16': 'float', 'float32': 'float',
+                       'float64': 'float', 'bool': 'i'}
+
 
 # Either object has cuda array interface or contains columns with interfaces
 def _has_cuda_array_interface(data):
@@ -242,7 +243,8 @@ def _has_cuda_array_interface(data):
             CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
 
 
-def _maybe_pandas_data(data, feature_names, feature_types):
+def _maybe_pandas_data(data, feature_names, feature_types,
+                       meta=None, meta_type=None):
     """Extract internal data from pd.DataFrame for DMatrix data"""
 
     if not (PANDAS_INSTALLED and isinstance(data, DataFrame)):
@@ -250,51 +252,41 @@ def _maybe_pandas_data(data, feature_names, feature_types):
 
     data_dtypes = data.dtypes
     if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
-        bad_fields = [str(data.columns[i]) for i, dtype in
-                      enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
+        bad_fields = [
+            str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
+            if dtype.name not in PANDAS_DTYPE_MAPPER
+        ]
 
         msg = """DataFrame.dtypes for data must be int, float or bool.
                 Did not expect the data types in fields """
         raise ValueError(msg + ', '.join(bad_fields))
 
-    if feature_names is None:
+    if feature_names is None and meta is None:
         if isinstance(data.columns, MultiIndex):
             feature_names = [
-                ' '.join([str(x) for x in i])
-                for i in data.columns
+                ' '.join([str(x) for x in i]) for i in data.columns
             ]
         elif isinstance(data.columns, Int64Index):
             feature_names = list(map(str, data.columns))
         else:
             feature_names = data.columns.format()
 
-    if feature_types is None:
-        feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes]
+    if feature_types is None and meta is None:
+        feature_types = [
+            PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes
+        ]
 
-    data = data.values.astype('float')
+    if meta and len(data.columns) > 1:
+        raise ValueError(
+            'DataFrame for {meta} cannot have multiple columns'.format(
+                meta=meta))
+
+    dtype = meta_type if meta_type else 'float'
+    data = data.values.astype(dtype)
 
     return data, feature_names, feature_types
 
 
-def _maybe_pandas_label(label):
-    """Extract internal data from pd.DataFrame for DMatrix label."""
-
-    if PANDAS_INSTALLED and isinstance(label, DataFrame):
-        if len(label.columns) > 1:
-            raise ValueError(
-                'DataFrame for label cannot have multiple columns')
-
-        label_dtypes = label.dtypes
-        if not all(dtype.name in PANDAS_DTYPE_MAPPER
-                   for dtype in label_dtypes):
-            raise ValueError(
-                'DataFrame.dtypes for label must be int, float or bool')
-        label = label.values.astype('float')
-    # pd.Series can be passed to xgb as it is
-
-    return label
-
-
 def _maybe_cudf_dataframe(data, feature_names, feature_types):
     """Extract internal data from cudf.DataFrame for DMatrix data."""
     if not (CUDF_INSTALLED and isinstance(data,
@@ -324,11 +316,21 @@ DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}
 DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
 
 
-def _maybe_dt_data(data, feature_names, feature_types):
+def _maybe_dt_data(data, feature_names, feature_types,
+                   meta=None, meta_type=None):
     """Validate feature names and types if data table"""
     if not isinstance(data, DataTable):
         return data, feature_names, feature_types
 
+    if meta and data.shape[1] > 1:
+        raise ValueError(
+            'DataTable for label or weight cannot have multiple columns')
+    if meta:
+        # below requires new dt version
+        # extract first column
+        data = data.to_numpy()[:, 0].astype(meta_type)
+        return data, None, None
+
     data_types_names = tuple(lt.name for lt in data.ltypes)
     bad_fields = [data.names[i]
                   for i, type_name in enumerate(data_types_names)
@@ -338,40 +340,31 @@ def _maybe_dt_data(data, feature_names, feature_types):
                 Did not expect the data types in fields """
         raise ValueError(msg + ', '.join(bad_fields))
 
-    if feature_names is None:
+    if feature_names is None and meta is None:
         feature_names = data.names
 
         # always return stypes for dt ingestion
         if feature_types is not None:
-            raise ValueError('DataTable has own feature types, cannot pass them in')
+            raise ValueError(
+                'DataTable has its own feature types, cannot pass them in.')
         feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
 
     return data, feature_names, feature_types
 
 
-def _maybe_dt_array(array):
-    """Extract numpy array from single column data table"""
-    if not isinstance(array, DataTable) or array is None:
-        return array
-
-    if array.shape[1] > 1:
-        raise ValueError('DataTable for label or weight cannot have multiple columns')
-
-    # below requires new dt version
-    # extract first column
-    array = array.to_numpy()[:, 0].astype('float')
-
-    return array
-
-
-def _convert_dataframes(data, feature_names, feature_types):
+def _convert_dataframes(data, feature_names, feature_types,
+                        meta=None, meta_type=None):
     data, feature_names, feature_types = _maybe_pandas_data(data,
                                                             feature_names,
-                                                            feature_types)
+                                                            feature_types,
+                                                            meta,
+                                                            meta_type)
 
     data, feature_names, feature_types = _maybe_dt_data(data,
                                                         feature_names,
-                                                        feature_types)
+                                                        feature_types,
+                                                        meta,
+                                                        meta_type)
 
     data, feature_names, feature_types = _maybe_cudf_dataframe(
         data, feature_names, feature_types)
@@ -379,6 +372,23 @@ def _convert_dataframes(data, feature_names, feature_types):
     return data, feature_names, feature_types
 
 
+def _maybe_np_slice(data, dtype=np.float32):
+    '''Handle numpy slice.  This can be removed if we use __array_interface__.
+    '''
+    try:
+        if not data.flags.c_contiguous:
+            warnings.warn(
+                "Using a subset (sliced data) of np.ndarray is not " +
+                "recommended because it will generate extra copies and " +
+                "increase memory consumption")
+            data = np.array(data, copy=True, dtype=dtype)
+        else:
+            data = np.array(data, copy=False, dtype=dtype)
+    except AttributeError:
+        data = np.array(data, copy=False, dtype=dtype)
+    return data
+
+
 class DMatrix(object):
     """Data Matrix used in XGBoost.
 
@@ -415,10 +425,10 @@ class DMatrix(object):
 
             .. note:: For ranking task, weights are per-group.
 
-                In ranking task, one weight is assigned to each group (not each data
-                point). This is because we only care about the relative ordering of
-                data points within each group, so it doesn't make sense to assign
-                weights to individual data points.
+                In ranking task, one weight is assigned to each group (not each
+                data point). This is because we only care about the relative
+                ordering of data points within each group, so it doesn't make
+                sense to assign weights to individual data points.
 
         silent : boolean, optional
             Whether print messages during construction
@@ -429,6 +439,7 @@ class DMatrix(object):
         nthread : integer, optional
             Number of threads to use for loading data from numpy array. If -1,
             uses maximum threads available on the system.
+
         """
         # force into void_p, mac need to pass things in as void_p
        if data is None:
@@ -447,11 +458,6 @@ class DMatrix(object):
             data, feature_names, feature_types
         )
 
-        label = _maybe_pandas_label(label)
-        label = _maybe_dt_array(label)
-        weight = _maybe_dt_array(weight)
-        base_margin = _maybe_dt_array(base_margin)
-
         if isinstance(data, (STRING_TYPES, os_PathLike)):
             handle = ctypes.c_void_p()
             _check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
@@ -491,41 +497,47 @@ class DMatrix(object):
     def _init_from_csr(self, csr):
         """Initialize data from a CSR matrix."""
         if len(csr.indices) != len(csr.data):
-            raise ValueError('length mismatch: {} vs {}'.format(len(csr.indices), len(csr.data)))
+            raise ValueError('length mismatch: {} vs {}'.format(
+                len(csr.indices), len(csr.data)))
         handle = ctypes.c_void_p()
-        _check_call(_LIB.XGDMatrixCreateFromCSREx(c_array(ctypes.c_size_t, csr.indptr),
-                                                  c_array(ctypes.c_uint, csr.indices),
-                                                  c_array(ctypes.c_float, csr.data),
-                                                  ctypes.c_size_t(len(csr.indptr)),
-                                                  ctypes.c_size_t(len(csr.data)),
-                                                  ctypes.c_size_t(csr.shape[1]),
-                                                  ctypes.byref(handle)))
+        _check_call(_LIB.XGDMatrixCreateFromCSREx(
+            c_array(ctypes.c_size_t, csr.indptr),
+            c_array(ctypes.c_uint, csr.indices),
+            c_array(ctypes.c_float, csr.data),
+            ctypes.c_size_t(len(csr.indptr)),
+            ctypes.c_size_t(len(csr.data)),
+            ctypes.c_size_t(csr.shape[1]),
+            ctypes.byref(handle)))
         self.handle = handle
 
     def _init_from_csc(self, csc):
         """Initialize data from a CSC matrix."""
         if len(csc.indices) != len(csc.data):
-            raise ValueError('length mismatch: {} vs {}'.format(len(csc.indices), len(csc.data)))
+            raise ValueError('length mismatch: {} vs {}'.format(
+                len(csc.indices), len(csc.data)))
         handle = ctypes.c_void_p()
-        _check_call(_LIB.XGDMatrixCreateFromCSCEx(c_array(ctypes.c_size_t, csc.indptr),
-                                                  c_array(ctypes.c_uint, csc.indices),
-                                                  c_array(ctypes.c_float, csc.data),
-                                                  ctypes.c_size_t(len(csc.indptr)),
-                                                  ctypes.c_size_t(len(csc.data)),
-                                                  ctypes.c_size_t(csc.shape[0]),
-                                                  ctypes.byref(handle)))
+        _check_call(_LIB.XGDMatrixCreateFromCSCEx(
+            c_array(ctypes.c_size_t, csc.indptr),
+            c_array(ctypes.c_uint, csc.indices),
+            c_array(ctypes.c_float, csc.data),
+            ctypes.c_size_t(len(csc.indptr)),
+            ctypes.c_size_t(len(csc.data)),
+            ctypes.c_size_t(csc.shape[0]),
+            ctypes.byref(handle)))
         self.handle = handle
 
     def _init_from_npy2d(self, mat, missing, nthread):
         """Initialize data from a 2-D numpy matrix.
 
-        If ``mat`` does not have ``order='C'`` (aka row-major) or is not contiguous,
-        a temporary copy will be made.
+        If ``mat`` does not have ``order='C'`` (aka row-major) or is
+        not contiguous, a temporary copy will be made.
 
-        If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will be made.
+        If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will
+        be made.
+
+        So there could be as many as two temporary data copies; be mindful of
+        input layout and type if memory use is a concern.
 
-        So there could be as many as two temporary data copies; be mindful of input layout
-        and type if memory use is a concern.
""" if len(mat.shape) != 2: raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ', @@ -536,21 +548,14 @@ class DMatrix(object): data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32) handle = ctypes.c_void_p() missing = missing if missing is not None else np.nan - if nthread is None: - _check_call(_LIB.XGDMatrixCreateFromMat( - data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), - c_bst_ulong(mat.shape[0]), - c_bst_ulong(mat.shape[1]), - ctypes.c_float(missing), - ctypes.byref(handle))) - else: - _check_call(_LIB.XGDMatrixCreateFromMat_omp( - data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), - c_bst_ulong(mat.shape[0]), - c_bst_ulong(mat.shape[1]), - ctypes.c_float(missing), - ctypes.byref(handle), - nthread)) + nthread = nthread if nthread is not None else 1 + _check_call(_LIB.XGDMatrixCreateFromMat_omp( + data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + c_bst_ulong(mat.shape[0]), + c_bst_ulong(mat.shape[1]), + ctypes.c_float(missing), + ctypes.byref(handle), + c_bst_ulong(nthread))) self.handle = handle def _init_from_dt(self, data, nthread): @@ -572,7 +577,8 @@ class DMatrix(object): # always return stypes for dt ingestion feature_type_strings = (ctypes.c_char_p * data.ncols)() for icol in range(data.ncols): - feature_type_strings[icol] = ctypes.c_char_p(data.stypes[icol].name.encode('utf-8')) + feature_type_strings[icol] = ctypes.c_char_p( + data.stypes[icol].name.encode('utf-8')) handle = ctypes.c_void_p() _check_call(_LIB.XGDMatrixCreateFromDT( @@ -598,7 +604,8 @@ class DMatrix(object): _check_call( _LIB.XGDMatrixCreateFromArrayInterfaceColumns( interfaces_str, - ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle))) + ctypes.c_float(missing), ctypes.c_int(nthread), + ctypes.byref(handle))) self.handle = handle def _init_from_array_interface(self, data, missing, nthread): @@ -614,7 +621,8 @@ class DMatrix(object): _check_call( _LIB.XGDMatrixCreateFromArrayInterface( interface_str, - ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle))) + ctypes.c_float(missing), ctypes.c_int(nthread), + ctypes.byref(handle))) self.handle = handle def __del__(self): @@ -675,6 +683,7 @@ class DMatrix(object): data: numpy array The array of data to be set """ + data, _, _ = _convert_dataframes(data, None, None, field, 'float') if isinstance(data, np.ndarray): self.set_float_info_npy2d(field, data) return @@ -684,20 +693,6 @@ class DMatrix(object): c_data, c_bst_ulong(len(data)))) - def set_interface_info(self, field, data): - """Set info type property into DMatrix.""" - - # If we are passed a dataframe, extract the series - if CUDF_INSTALLED and isinstance(data, CUDF_DataFrame): - if len(data.columns) != 1: - raise ValueError('Expecting meta-info to contain a single column') - data = data[data.columns[0]] - - interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), 'utf-8') - _check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle, - c_str(field), - interface)) - def set_float_info_npy2d(self, field, data): """Set float type property into the DMatrix for numpy 2d array input @@ -710,16 +705,7 @@ class DMatrix(object): data: numpy array The array of data to be set """ - try: - if not data.flags.c_contiguous: - warnings.warn("Use subset (sliced data) of np.ndarray is not recommended " + - "because it will generate extra copies and increase " + - "memory consumption") - data = np.array(data, copy=True, dtype=np.float32) - else: - data = np.array(data, copy=False, dtype=np.float32) - except AttributeError: - data = 
np.array(data, copy=False, dtype=np.float32) + data = _maybe_np_slice(data, np.float32) c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle, c_str(field), @@ -737,21 +723,29 @@ class DMatrix(object): data: numpy array The array of data to be set """ - try: - if not data.flags.c_contiguous: - warnings.warn("Use subset (sliced data) of np.ndarray is not recommended " + - "because it will generate extra copies and increase " + - "memory consumption") - data = np.array(data, copy=True, dtype=ctypes.c_uint) - else: - data = np.array(data, copy=False, dtype=ctypes.c_uint) - except AttributeError: - data = np.array(data, copy=False, dtype=ctypes.c_uint) + data = _maybe_np_slice(data, np.uint32) + data, _, _ = _convert_dataframes(data, None, None, field, 'uint32') + data = np.array(data, copy=False, dtype=ctypes.c_uint) _check_call(_LIB.XGDMatrixSetUIntInfo(self.handle, c_str(field), c_array(ctypes.c_uint, data), c_bst_ulong(len(data)))) + def set_interface_info(self, field, data): + """Set info type property into DMatrix.""" + # If we are passed a dataframe, extract the series + if CUDF_INSTALLED and isinstance(data, CUDF_DataFrame): + if len(data.columns) != 1: + raise ValueError( + 'Expecting meta-info to contain a single column') + data = data[data.columns[0]] + + interface = bytes(json.dumps([data.__cuda_array_interface__], + indent=2), 'utf-8') + _check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle, + c_str(field), + interface)) + def save_binary(self, fname, silent=True): """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded by providing the path to :py:func:`xgboost.DMatrix` as input. @@ -775,26 +769,13 @@ class DMatrix(object): label: array like The label information to be set into DMatrix """ - if isinstance(label, np.ndarray): - self.set_label_npy2d(label) - elif _has_cuda_array_interface(label): + if _has_cuda_array_interface(label): self.set_interface_info('label', label) else: self.set_float_info('label', label) - def set_label_npy2d(self, label): - """Set label of dmatrix - - Parameters - ---------- - label: array like - The label information to be set into DMatrix - from numpy 2D array - """ - self.set_float_info_npy2d('label', label) - def set_weight(self, weight): - """ Set weight of each instance. + """Set weight of each instance. Parameters ---------- @@ -803,49 +784,30 @@ class DMatrix(object): .. note:: For ranking task, weights are per-group. - In ranking task, one weight is assigned to each group (not each data - point). This is because we only care about the relative ordering of - data points within each group, so it doesn't make sense to assign - weights to individual data points. + In ranking task, one weight is assigned to each group (not each + data point). This is because we only care about the relative + ordering of data points within each group, so it doesn't make + sense to assign weights to individual data points. + """ - if isinstance(weight, np.ndarray): - self.set_weight_npy2d(weight) - elif _has_cuda_array_interface(weight): + if _has_cuda_array_interface(weight): self.set_interface_info('weight', weight) else: self.set_float_info('weight', weight) - def set_weight_npy2d(self, weight): - """ Set weight of each instance - for numpy 2D array - - Parameters - ---------- - weight : array like - Weight for each data point in numpy 2D array - - .. note:: For ranking task, weights are per-group. - - In ranking task, one weight is assigned to each group (not each data - point). 
This is because we only care about the relative ordering of - data points within each group, so it doesn't make sense to assign - weights to individual data points. - """ - self.set_float_info_npy2d('weight', weight) - def set_base_margin(self, margin): - """ Set base margin of booster to start from. + """Set base margin of booster to start from. - This can be used to specify a prediction value of - existing model to be base_margin - However, remember margin is needed, instead of transformed prediction - e.g. for logistic regression: need to put in value before logistic transformation - see also example/demo.py + This can be used to specify a prediction value of existing model to be + base_margin However, remember margin is needed, instead of transformed + prediction e.g. for logistic regression: need to put in value before + logistic transformation see also example/demo.py Parameters ---------- margin: array like Prediction margin of each datapoint + """ if _has_cuda_array_interface(margin): self.set_interface_info('base_margin', margin) diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index 9b25a03dc..117dcd243 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -920,7 +920,10 @@ class LambdaRankObj : public ObjFunction { std::vector tgptr(2, 0); tgptr[1] = static_cast(info.labels_.Size()); const std::vector &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; CHECK(gptr.size() != 0 && gptr.back() == info.labels_.Size()) - << "group structure not consistent with #rows"; + << "group structure not consistent with #rows" << ", " + << "group ponter size: " << gptr.size() << ", " + << "labels size: " << info.labels_.Size() << ", " + << "group pointer back: " << (gptr.size() == 0 ? 0 : gptr.back()); #if defined(__CUDACC__) // Check if we have a GPU assignment; else, revert back to CPU diff --git a/tests/python/test_dt.py b/tests/python/test_dt.py index cd138fa53..d6dea3e7b 100644 --- a/tests/python/test_dt.py +++ b/tests/python/test_dt.py @@ -1,51 +1,54 @@ -# -*- coding: utf-8 -*- -import unittest -import pytest - -import testing as tm -import xgboost as xgb - -try: - import datatable as dt - import pandas as pd -except ImportError: - pass - -pytestmark = pytest.mark.skipif( - tm.no_dt()['condition'] or tm.no_pandas()['condition'], - reason=tm.no_dt()['reason'] + ' or ' + tm.no_pandas()['reason']) - - -class TestDataTable(unittest.TestCase): - - def test_dt(self): - df = pd.DataFrame([[1, 2., True], [2, 3., False]], - columns=['a', 'b', 'c']) - dtable = dt.Frame(df) - labels = dt.Frame([1, 2]) - dm = xgb.DMatrix(dtable, label=labels) - assert dm.feature_names == ['a', 'b', 'c'] - assert dm.feature_types == ['int', 'float', 'i'] - assert dm.num_row() == 2 - assert dm.num_col() == 3 - - # overwrite feature_names - dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]), - feature_names=['x', 'y', 'z']) - assert dm.feature_names == ['x', 'y', 'z'] - assert dm.num_row() == 2 - assert dm.num_col() == 3 - - # incorrect dtypes - df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], - columns=['a', 'b', 'c']) - dtable = dt.Frame(df) - self.assertRaises(ValueError, xgb.DMatrix, dtable) - - df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]}) - dtable = dt.Frame(df) - dm = xgb.DMatrix(dtable) - assert dm.feature_names == ['A=1', 'A=2'] - assert dm.feature_types == ['int', 'int'] - assert dm.num_row() == 3 - assert dm.num_col() == 2 +# -*- coding: utf-8 -*- +import unittest +import pytest +import numpy as np + +import testing as tm +import xgboost as xgb + +try: + 
+    import datatable as dt
+    import pandas as pd
+except ImportError:
+    pass
+
+pytestmark = pytest.mark.skipif(
+    tm.no_dt()['condition'] or tm.no_pandas()['condition'],
+    reason=tm.no_dt()['reason'] + ' or ' + tm.no_pandas()['reason'])
+
+
+class TestDataTable(unittest.TestCase):
+
+    def test_dt(self):
+        df = pd.DataFrame([[1, 2., True], [2, 3., False]],
+                          columns=['a', 'b', 'c'])
+        dtable = dt.Frame(df)
+        labels = dt.Frame([1, 2])
+        dm = xgb.DMatrix(dtable, label=labels)
+        assert dm.feature_names == ['a', 'b', 'c']
+        assert dm.feature_types == ['int', 'float', 'i']
+        assert dm.num_row() == 2
+        assert dm.num_col() == 3
+
+        np.testing.assert_array_equal(np.array([1, 2]), dm.get_label())
+
+        # overwrite feature_names
+        dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
+                         feature_names=['x', 'y', 'z'])
+        assert dm.feature_names == ['x', 'y', 'z']
+        assert dm.num_row() == 2
+        assert dm.num_col() == 3
+
+        # incorrect dtypes
+        df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
+                          columns=['a', 'b', 'c'])
+        dtable = dt.Frame(df)
+        self.assertRaises(ValueError, xgb.DMatrix, dtable)
+
+        df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
+        dtable = dt.Frame(df)
+        dm = xgb.DMatrix(dtable)
+        assert dm.feature_names == ['A=1', 'A=2']
+        assert dm.feature_types == ['int', 'int']
+        assert dm.num_row() == 3
+        assert dm.num_col() == 2
diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py
index e91627713..4da606269 100644
--- a/tests/python/test_with_pandas.py
+++ b/tests/python/test_with_pandas.py
@@ -29,6 +29,7 @@ class TestPandas(unittest.TestCase):
         assert dm.feature_types == ['int', 'float', 'i']
         assert dm.num_row() == 2
         assert dm.num_col() == 3
+        np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
 
         # overwrite feature_names and feature_types
         dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
@@ -51,6 +52,7 @@ class TestPandas(unittest.TestCase):
         assert dm.feature_types == ['int', 'float', 'i']
         assert dm.num_row() == 2
         assert dm.num_col() == 3
+        np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
 
         df = pd.DataFrame([[1, 2., 1], [2, 3., 1]], columns=[4, 5, 6])
         dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
@@ -110,21 +112,38 @@ class TestPandas(unittest.TestCase):
     def test_pandas_label(self):
         # label must be a single column
         df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
-        self.assertRaises(ValueError, xgb.core._maybe_pandas_label, df)
+        self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
+                          None, None, 'label', 'float')
 
         # label must be supported dtype
         df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
-        self.assertRaises(ValueError, xgb.core._maybe_pandas_label, df)
+        self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
+                          None, None, 'label', 'float')
 
         df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
-        result = xgb.core._maybe_pandas_label(df)
+        result, _, _ = xgb.core._maybe_pandas_data(df, None, None,
+                                                   'label', 'float')
         np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
                                                        dtype=float))
-
         dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
         assert dm.num_row() == 3
         assert dm.num_col() == 2
 
+    def test_pandas_weight(self):
+        kRows = 32
+        kCols = 8
+
+        X = np.random.randn(kRows, kCols)
+        y = np.random.randn(kRows)
+        w = np.random.randn(kRows).astype(np.float32)
+        w_pd = pd.DataFrame(w)
+        data = xgb.DMatrix(X, y, w_pd)
+
+        assert data.num_row() == kRows
+        assert data.num_col() == kCols
+
+        np.testing.assert_array_equal(data.get_weight(), w)
+
     def test_cv_as_pandas(self):
         dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
         params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 11f288384..19195d857 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -97,6 +97,7 @@ def test_ranking():
     valid_data = xgb.DMatrix(x_valid, y_valid)
     test_data = xgb.DMatrix(x_test)
     train_data.set_group(train_group)
+    assert train_data.get_label().shape[0] == x_train.shape[0]
     valid_data.set_group(valid_group)
 
     params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise',