Save model in ubj as the default. (#9947)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import csv
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -24,20 +25,18 @@ class TestDMatrix:
|
||||
with pytest.warns(UserWarning):
|
||||
data._warn_unused_missing("uri", 4)
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
data._warn_unused_missing("uri", None)
|
||||
data._warn_unused_missing("uri", np.nan)
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
x = rng.randn(10, 10)
|
||||
y = rng.randn(10)
|
||||
|
||||
xgb.DMatrix(x, y, missing=4)
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
def test_dmatrix_numpy_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
dm = xgb.DMatrix(data)
|
||||
@@ -264,7 +263,7 @@ class TestDMatrix:
|
||||
dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
|
||||
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
|
||||
watchlist = [(dtrain, "train")]
|
||||
param = {"max_depth": 3, "objective": "binary:logistic", "verbosity": 0}
|
||||
param = {"max_depth": 3, "objective": "binary:logistic"}
|
||||
bst = xgb.train(param, dtrain, 5, watchlist)
|
||||
bst.predict(dtrain)
|
||||
|
||||
@@ -302,7 +301,7 @@ class TestDMatrix:
|
||||
dtrain = xgb.DMatrix(x, label=rng.binomial(1, 0.3, nrow))
|
||||
assert (dtrain.num_row(), dtrain.num_col()) == (nrow, ncol)
|
||||
watchlist = [(dtrain, "train")]
|
||||
param = {"max_depth": 3, "objective": "binary:logistic", "verbosity": 0}
|
||||
param = {"max_depth": 3, "objective": "binary:logistic"}
|
||||
bst = xgb.train(param, dtrain, 5, watchlist)
|
||||
bst.predict(dtrain)
|
||||
|
||||
@@ -475,17 +474,19 @@ class TestDMatrixColumnSplit:
|
||||
def test_uri(self):
|
||||
def verify_uri():
|
||||
rank = xgb.collective.get_rank()
|
||||
data = np.random.rand(5, 5)
|
||||
filename = f"test_data_{rank}.csv"
|
||||
with open(filename, mode="w", newline="") as file:
|
||||
writer = csv.writer(file)
|
||||
for row in data:
|
||||
writer.writerow(row)
|
||||
dtrain = xgb.DMatrix(
|
||||
f"{filename}?format=csv", data_split_mode=DataSplitMode.COL
|
||||
)
|
||||
assert dtrain.num_row() == 5
|
||||
assert dtrain.num_col() == 5 * xgb.collective.get_world_size()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
filename = os.path.join(tmpdir, f"test_data_{rank}.csv")
|
||||
|
||||
data = np.random.rand(5, 5)
|
||||
with open(filename, mode="w", newline="") as file:
|
||||
writer = csv.writer(file)
|
||||
for row in data:
|
||||
writer.writerow(row)
|
||||
dtrain = xgb.DMatrix(
|
||||
f"{filename}?format=csv", data_split_mode=DataSplitMode.COL
|
||||
)
|
||||
assert dtrain.num_row() == 5
|
||||
assert dtrain.num_col() == 5 * xgb.collective.get_world_size()
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_uri)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user