Fix DMatrix feature names/types IO. (#6507)

* Fix feature names/types IO

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-12-16 14:24:27 +08:00
committed by GitHub
parent 886486a519
commit ef4a0e0aac
2 changed files with 42 additions and 16 deletions

View File

@@ -1,10 +1,14 @@
# -*- coding: utf-8 -*-
import os
import tempfile
import numpy as np
import xgboost as xgb
import scipy.sparse
import pytest
from scipy.sparse import rand, csr_matrix
import testing as tm
# Module-level random state with a fixed seed (1).
rng = np.random.RandomState(1)
# Directory prefix prepended to demo data file names when loading them.
dpath = 'demo/data/'
@@ -207,6 +211,23 @@ class TestDMatrix:
with pytest.raises(ValueError):
bst.predict(dm)
@pytest.mark.skipif(**tm.no_pandas())
def test_save_binary(self):
    """Feature names and types must survive a save_binary / reload cycle.

    A DMatrix is built from a pandas DataFrame (columns "a" and "b" as
    features, column "c" as the label), written to a file inside a
    temporary directory, and loaded back from that file.  The reloaded
    matrix is expected to report the same feature names and feature
    types as the one that was saved.
    """
    import pandas as pd

    frame = pd.DataFrame({"a": [0, 1], "b": [2, 3], "c": [4, 5]})
    features = frame.loc[:, ["a", "b"]]
    label = frame["c"]

    with tempfile.TemporaryDirectory() as tmpdir:
        binary_path = os.path.join(tmpdir, 'm.dmatrix')

        original = xgb.DMatrix(features, label)
        # Sanity check: the names are taken from the DataFrame columns.
        assert original.feature_names == ['a', 'b']

        original.save_binary(binary_path)
        restored = xgb.DMatrix(binary_path)

        assert original.feature_names == restored.feature_names
        assert original.feature_types == restored.feature_types
def test_get_info(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtrain.get_float_info('label')