Support pandas nullable types. (#7760)

This commit is contained in:
Jiaming Yuan
2022-03-30 08:51:52 +08:00
committed by GitHub
parent d4796482b5
commit 9150fdbd4d
2 changed files with 66 additions and 3 deletions

View File

@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import os
import tempfile
import numpy as np
import xgboost as xgb
import testing as tm
@@ -293,3 +294,39 @@ class TestPandas:
assert isinstance(params, list)
assert 'auc' not in cv.columns[0]
assert 'error' in cv.columns[0]
def test_nullable_type(self):
    """Check DMatrix construction from pandas nullable extension dtypes.

    Each DMatrix is serialized to its binary form and the raw bytes are
    compared:

    * a nullable ``Int16`` frame must serialize identically to the same
      data held as ``float32``;
    * a nullable ``boolean`` frame must serialize differently from a
      plain NumPy bool frame, because the latter turns ``None`` into
      ``False``.
    """
    y = np.random.default_rng(0).random(4)

    def to_bytes(Xy: xgb.DMatrix) -> bytes:
        """Save ``Xy`` to a temporary binary file and return its content."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "Xy.dmatrix")
            Xy.save_binary(path)
            with open(path, "rb") as fd:
                return fd.read()

    def int_frame_bytes(dtype) -> bytes:
        """Serialize a DMatrix built from an integer frame with a missing row."""
        arr = pd.DataFrame(
            {"f0": [1, 2, None, 3], "f1": [4, 3, None, 1]}, dtype=dtype
        )
        Xy = xgb.DMatrix(arr, y)
        # NOTE(review): presumably cleared so that dtype-derived feature
        # types do not leak into the byte comparison -- confirm.
        Xy.feature_types = None
        return to_bytes(Xy)

    float_bytes = int_frame_bytes(np.float32)
    nullable_int_bytes = int_frame_bytes(pd.Int16Dtype())
    assert float_bytes == nullable_int_bytes

    def bool_frame_bytes(dtype) -> bytes:
        """Serialize a DMatrix built from a boolean frame with a missing row."""
        arr = pd.DataFrame(
            {"f0": [True, False, None, True], "f1": [False, True, None, True]},
            dtype=dtype,
        )
        Xy = xgb.DMatrix(arr, y)
        # NOTE(review): same presumed reason as in int_frame_bytes -- confirm.
        Xy.feature_types = None
        return to_bytes(Xy)

    nullable_bool_bytes = bool_frame_bytes(pd.BooleanDtype())
    # ``np.bool`` was deprecated in NumPy 1.20 and removed in 1.24;
    # ``np.bool_`` is the scalar type available in every NumPy release.
    numpy_bool_bytes = bool_frame_bytes(np.bool_)
    # None is converted to False with a plain NumPy bool dtype.
    assert nullable_bool_bytes != numpy_bool_bytes