Allow import via python datatable. (#3272)
* Allow import via python datatable. * Write unit tests * Refactor dt API functions * Refactor python code * Lint fixes * Address review comments
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
eecf341ea7
commit
9ac163d0bb
@@ -3,6 +3,34 @@
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
TEST(c_api, XGDMatrixCreateFromMatDT) {
|
||||
std::vector<int> col0 = {0, -1, 3};
|
||||
std::vector<float> col1 = {-4.0f, 2.0f, 0.0f};
|
||||
const char *col0_type = "int32";
|
||||
const char *col1_type = "float32";
|
||||
std::vector<void *> data = {col0.data(), col1.data()};
|
||||
std::vector<const char *> types = {col0_type, col1_type};
|
||||
DMatrixHandle handle;
|
||||
XGDMatrixCreateFromDT(data.data(), types.data(), 3, 2, &handle,
|
||||
0);
|
||||
std::shared_ptr<xgboost::DMatrix> dmat =
|
||||
*static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
xgboost::MetaInfo &info = dmat->Info();
|
||||
ASSERT_EQ(info.num_col_, 2);
|
||||
ASSERT_EQ(info.num_row_, 3);
|
||||
ASSERT_EQ(info.num_nonzero_, 6);
|
||||
|
||||
auto iter = dmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
auto batch = iter->Value();
|
||||
ASSERT_EQ(batch[0][0].fvalue, 0.0f);
|
||||
ASSERT_EQ(batch[0][1].fvalue, -4.0f);
|
||||
ASSERT_EQ(batch[2][0].fvalue, 3.0f);
|
||||
ASSERT_EQ(batch[2][1].fvalue, 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(c_api, XGDMatrixCreateFromMat_omp) {
|
||||
std::vector<int> num_rows = {100, 11374, 15000};
|
||||
for (auto row : num_rows) {
|
||||
|
||||
47
tests/python/test_dt.py
Normal file
47
tests/python/test_dt.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import unittest
|
||||
|
||||
import testing as tm
|
||||
import xgboost as xgb
|
||||
|
||||
try:
|
||||
import datatable as dt
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
tm._skip_if_no_dt()
|
||||
tm._skip_if_no_pandas()
|
||||
|
||||
|
||||
class TestDataTable(unittest.TestCase):
|
||||
|
||||
def test_dt(self):
|
||||
df = pd.DataFrame([[1, 2., True], [2, 3., False]], columns=['a', 'b', 'c'])
|
||||
dtable = dt.Frame(df)
|
||||
labels = dt.Frame([1, 2])
|
||||
dm = xgb.DMatrix(dtable, label=labels)
|
||||
assert dm.feature_names == ['a', 'b', 'c']
|
||||
assert dm.feature_types == ['int', 'float', 'i']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
# overwrite feature_names
|
||||
dm = xgb.DMatrix(dtable, label=pd.Series([1, 2]),
|
||||
feature_names=['x', 'y', 'z'])
|
||||
assert dm.feature_names == ['x', 'y', 'z']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
# incorrect dtypes
|
||||
df = pd.DataFrame([[1, 2., 'x'], [2, 3., 'y']], columns=['a', 'b', 'c'])
|
||||
dtable = dt.Frame(df)
|
||||
self.assertRaises(ValueError, xgb.DMatrix, dtable)
|
||||
|
||||
df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
|
||||
dtable = dt.Frame(df)
|
||||
dm = xgb.DMatrix(dtable)
|
||||
assert dm.feature_names == ['A=1', 'A=2']
|
||||
assert dm.feature_types == ['int', 'int']
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import nose
|
||||
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
|
||||
|
||||
|
||||
def _skip_if_no_sklearn():
|
||||
@@ -15,6 +15,11 @@ def _skip_if_no_pandas():
|
||||
raise nose.SkipTest()
|
||||
|
||||
|
||||
def _skip_if_no_dt():
|
||||
if not DT_INSTALLED:
|
||||
raise nose.SkipTest()
|
||||
|
||||
|
||||
def _skip_if_no_matplotlib():
|
||||
try:
|
||||
import matplotlib.pyplot as _ # noqa
|
||||
|
||||
@@ -48,6 +48,13 @@ if [ ${TASK} == "python_test" ]; then
|
||||
source activate python3
|
||||
python --version
|
||||
conda install numpy scipy pandas matplotlib nose scikit-learn
|
||||
|
||||
# Install data table from source
|
||||
wget http://releases.llvm.org/5.0.2/clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04.tar.xz
|
||||
tar xf clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04.tar.xz
|
||||
export LLVM5=$(pwd)/clang+llvm-5.0.2-x86_64-linux-gnu-ubuntu-14.04
|
||||
python -m pip install datatable --no-binary datatable
|
||||
|
||||
python -m pip install graphviz pytest pytest-cov codecov
|
||||
python -m nose tests/python || exit -1
|
||||
py.test tests/python --cov=python-package/xgboost
|
||||
|
||||
Reference in New Issue
Block a user