Support column-wise data split with in-memory inputs (#9628)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -6,6 +7,7 @@ import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.core import DataSplitMode
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
@@ -97,3 +99,17 @@ class TestArrowTable:
|
||||
y_np_low = dtrain.get_float_info("label_lower_bound")
|
||||
np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values)
|
||||
np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values)
|
||||
|
||||
|
||||
class TestArrowTableColumnSplit:
|
||||
def test_arrow_table(self):
|
||||
def verify_arrow_table():
|
||||
df = pd.DataFrame(
|
||||
[[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"]
|
||||
)
|
||||
table = pa.Table.from_pandas(df)
|
||||
dm = xgb.DMatrix(table, data_split_mode=DataSplitMode.COL)
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 4 * xgb.collective.get_world_size()
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_arrow_table)
|
||||
|
||||
Reference in New Issue
Block a user