Remove extra sync for dense data (#7120)
Co-authored-by: SHVETS, KIRILL <kirill.shvets@intel.com>
This commit is contained in:
parent
e6088366df
commit
caa9e527dd
@ -329,8 +329,8 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
|
|||||||
for (const size_t *it = e.begin; it < e.end; ++it) {
|
for (const size_t *it = e.begin; it < e.end; ++it) {
|
||||||
grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
|
grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess());
|
||||||
}
|
}
|
||||||
|
histred_.Allreduce(&grad_stat, 1);
|
||||||
}
|
}
|
||||||
histred_.Allreduce(&grad_stat, 1);
|
|
||||||
|
|
||||||
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
|
auto weight = evaluator_->InitRoot(GradStats{grad_stat});
|
||||||
p_tree->Stat(RegTree::kRoot).sum_hess = grad_stat.GetHess();
|
p_tree->Stat(RegTree::kRoot).sum_hess = grad_stat.GetHess();
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from hypothesis import given, settings, note, HealthCheck
|
|||||||
from test_updaters import hist_parameter_strategy, exact_parameter_strategy
|
from test_updaters import hist_parameter_strategy, exact_parameter_strategy
|
||||||
from test_with_sklearn import run_feature_weights, run_data_initialization
|
from test_with_sklearn import run_feature_weights, run_data_initialization
|
||||||
from test_predict import verify_leaf_output
|
from test_predict import verify_leaf_output
|
||||||
|
from sklearn.datasets import make_regression
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||||
@ -1493,6 +1494,37 @@ def test_parallel_submits(client: "Client") -> None:
|
|||||||
for i, cls in enumerate(classifiers):
|
for i, cls in enumerate(classifiers):
|
||||||
assert cls.get_booster().num_boosted_rounds() == i + 1
|
assert cls.get_booster().num_boosted_rounds() == i + 1
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
|
||||||
|
def test_hist_root_stats_with_different_num_worker(tree_method: str) -> None:
|
||||||
|
"""assert that different workers count dosn't affect summ statistic's on root"""
|
||||||
|
def dask_train(n_workers, X, y, num_obs, num_features):
|
||||||
|
cluster = LocalCluster(n_workers=n_workers)
|
||||||
|
client = Client(cluster)
|
||||||
|
|
||||||
|
chunk_size = num_obs/n_workers
|
||||||
|
X = da.from_array(X, chunks=(chunk_size, num_features))
|
||||||
|
y = da.from_array(y.reshape(num_obs,1), chunks=(chunk_size, 1))
|
||||||
|
dtrain = xgb.dask.DaskDMatrix(client, X, y)
|
||||||
|
|
||||||
|
output = xgb.dask.train(
|
||||||
|
client,
|
||||||
|
{"verbosity": 0, "tree_method": tree_method, "objective": "reg:squarederror", 'max_depth': 2},
|
||||||
|
dtrain,
|
||||||
|
num_boost_round=1
|
||||||
|
)
|
||||||
|
dump_model = output['booster'].get_dump(with_stats=True)
|
||||||
|
client.shutdown()
|
||||||
|
return dump_model
|
||||||
|
|
||||||
|
num_obs = 1000
|
||||||
|
num_features = 10
|
||||||
|
X, y = make_regression(num_obs, num_features, random_state=777)
|
||||||
|
first_model = dask_train(1, X, y, num_obs, num_features)[0]
|
||||||
|
second_model = dask_train(2, X, y, num_obs, num_features)[0]
|
||||||
|
first_summ_stats = first_model[first_model.find('cover='):first_model.find('\n')]
|
||||||
|
second_summ_stats = second_model[second_model.find('cover='):second_model.find('\n')]
|
||||||
|
assert first_summ_stats == second_summ_stats
|
||||||
|
|
||||||
|
|
||||||
def test_parallel_submit_multi_clients() -> None:
|
def test_parallel_submit_multi_clients() -> None:
|
||||||
"""Test for running multiple train simultaneously from multiple clients."""
|
"""Test for running multiple train simultaneously from multiple clients."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user