Fix DMatrix feature names/types IO. (#6507)
* Fix feature names/types IO Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import scipy.sparse
|
||||
import pytest
|
||||
from scipy.sparse import rand, csr_matrix
|
||||
|
||||
import testing as tm
|
||||
|
||||
rng = np.random.RandomState(1)
|
||||
|
||||
dpath = 'demo/data/'
|
||||
@@ -207,6 +211,23 @@ class TestDMatrix:
|
||||
with pytest.raises(ValueError):
|
||||
bst.predict(dm)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_save_binary(self):
|
||||
import pandas as pd
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'm.dmatrix')
|
||||
data = pd.DataFrame({
|
||||
"a": [0, 1],
|
||||
"b": [2, 3],
|
||||
"c": [4, 5]
|
||||
})
|
||||
m0 = xgb.DMatrix(data.loc[:, ["a", "b"]], data["c"])
|
||||
assert m0.feature_names == ['a', 'b']
|
||||
m0.save_binary(path)
|
||||
m1 = xgb.DMatrix(path)
|
||||
assert m0.feature_names == m1.feature_names
|
||||
assert m0.feature_types == m1.feature_types
|
||||
|
||||
def test_get_info(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtrain.get_float_info('label')
|
||||
|
||||
Reference in New Issue
Block a user