Fix metainfo from DataFrame. (#5216)
* Fix metainfo from DataFrame. * Unify helper functions for data and meta.
This commit is contained in:
@@ -1,51 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
import testing as tm
|
||||
import xgboost as xgb
|
||||
|
||||
try:
|
||||
import datatable as dt
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
tm.no_dt()['condition'] or tm.no_pandas()['condition'],
|
||||
reason=tm.no_dt()['reason'] + ' or ' + tm.no_pandas()['reason'])
|
||||
|
||||
|
||||
class TestDataTable(unittest.TestCase):
|
||||
|
||||
def test_dt(self):
|
||||
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
|
||||
columns=['a', 'b', 'c'])
|
||||
dtable = dt.Frame(df)
|
||||
labels = dt.Frame([1, 2])
|
||||
dm = xgb.DMatrix(dtable, label=labels)
|
||||
assert dm.feature_names == ['a', 'b', 'c']
|
||||
assert dm.feature_types == ['int', 'float', 'i']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
# overwrite feature_names
|
||||
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
|
||||
feature_names=['x', 'y', 'z'])
|
||||
assert dm.feature_names == ['x', 'y', 'z']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
# incorrect dtypes
|
||||
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
|
||||
columns=['a', 'b', 'c'])
|
||||
dtable = dt.Frame(df)
|
||||
self.assertRaises(ValueError, xgb.DMatrix, dtable)
|
||||
|
||||
df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
|
||||
dtable = dt.Frame(df)
|
||||
dm = xgb.DMatrix(dtable)
|
||||
assert dm.feature_names == ['A=1', 'A=2']
|
||||
assert dm.feature_types == ['int', 'int']
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
# -*- coding: utf-8 -*-
|
||||
import unittest
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import testing as tm
|
||||
import xgboost as xgb
|
||||
|
||||
try:
|
||||
import datatable as dt
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
tm.no_dt()['condition'] or tm.no_pandas()['condition'],
|
||||
reason=tm.no_dt()['reason'] + ' or ' + tm.no_pandas()['reason'])
|
||||
|
||||
|
||||
class TestDataTable(unittest.TestCase):
|
||||
|
||||
def test_dt(self):
|
||||
df = pd.DataFrame([[1, 2., True], [2, 3., False]],
|
||||
columns=['a', 'b', 'c'])
|
||||
dtable = dt.Frame(df)
|
||||
labels = dt.Frame([1, 2])
|
||||
dm = xgb.DMatrix(dtable, label=labels)
|
||||
assert dm.feature_names == ['a', 'b', 'c']
|
||||
assert dm.feature_types == ['int', 'float', 'i']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
np.testing.assert_array_equal(np.array([1, 2]), dm.get_label())
|
||||
|
||||
# overwrite feature_names
|
||||
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
|
||||
feature_names=['x', 'y', 'z'])
|
||||
assert dm.feature_names == ['x', 'y', 'z']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
# incorrect dtypes
|
||||
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
|
||||
columns=['a', 'b', 'c'])
|
||||
dtable = dt.Frame(df)
|
||||
self.assertRaises(ValueError, xgb.DMatrix, dtable)
|
||||
|
||||
df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
|
||||
dtable = dt.Frame(df)
|
||||
dm = xgb.DMatrix(dtable)
|
||||
assert dm.feature_names == ['A=1', 'A=2']
|
||||
assert dm.feature_types == ['int', 'int']
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
|
||||
@@ -29,6 +29,7 @@ class TestPandas(unittest.TestCase):
|
||||
assert dm.feature_types == ['int', 'float', 'i']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
|
||||
|
||||
# overwrite feature_names and feature_types
|
||||
dm = xgb.DMatrix(df, label=pd.Series([1, 2]),
|
||||
@@ -51,6 +52,7 @@ class TestPandas(unittest.TestCase):
|
||||
assert dm.feature_types == ['int', 'float', 'i']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
|
||||
|
||||
df = pd.DataFrame([[1, 2., 1], [2, 3., 1]], columns=[4, 5, 6])
|
||||
dm = xgb.DMatrix(df, label=pd.Series([1, 2]))
|
||||
@@ -110,21 +112,38 @@ class TestPandas(unittest.TestCase):
|
||||
def test_pandas_label(self):
|
||||
# label must be a single column
|
||||
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||
self.assertRaises(ValueError, xgb.core._maybe_pandas_label, df)
|
||||
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
|
||||
None, None, 'label', 'float')
|
||||
|
||||
# label must be supported dtype
|
||||
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
|
||||
self.assertRaises(ValueError, xgb.core._maybe_pandas_label, df)
|
||||
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
|
||||
None, None, 'label', 'float')
|
||||
|
||||
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
|
||||
result = xgb.core._maybe_pandas_label(df)
|
||||
result, _, _ = xgb.core._maybe_pandas_data(df, None, None,
|
||||
'label', 'float')
|
||||
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
|
||||
dtype=float))
|
||||
|
||||
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
|
||||
def test_pandas_weight(self):
|
||||
kRows = 32
|
||||
kCols = 8
|
||||
|
||||
X = np.random.randn(kRows, kCols)
|
||||
y = np.random.randn(kRows)
|
||||
w = np.random.randn(kRows).astype(np.float32)
|
||||
w_pd = pd.DataFrame(w)
|
||||
data = xgb.DMatrix(X, y, w_pd)
|
||||
|
||||
assert data.num_row() == kRows
|
||||
assert data.num_col() == kCols
|
||||
|
||||
np.testing.assert_array_equal(data.get_weight(), w)
|
||||
|
||||
def test_cv_as_pandas(self):
|
||||
dm = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
params = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
||||
|
||||
@@ -97,6 +97,7 @@ def test_ranking():
|
||||
valid_data = xgb.DMatrix(x_valid, y_valid)
|
||||
test_data = xgb.DMatrix(x_test)
|
||||
train_data.set_group(train_group)
|
||||
assert train_data.get_label().shape[0] == x_train.shape[0]
|
||||
valid_data.set_group(valid_group)
|
||||
|
||||
params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise',
|
||||
|
||||
Reference in New Issue
Block a user