More in-memory input support for column split (#9685)

This commit is contained in:
Rong Ou
2023-10-20 01:02:36 -07:00
committed by GitHub
parent 83cdf14b2c
commit 6fbe6248f4
5 changed files with 479 additions and 213 deletions

View File

@@ -8,6 +8,7 @@ import importlib.util
import multiprocessing
import os
import platform
import queue
import socket
import sys
import threading
@@ -942,13 +943,20 @@ def project_root(path: str) -> str:
return normpath(os.path.join(demo_dir(path), os.path.pardir))
def run_with_rabit(world_size: int, test_fn: Callable) -> None:
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
tracker.start(world_size)
def run_with_rabit(
world_size: int, test_fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> None:
exception_queue: queue.Queue = queue.Queue()
def run_worker(rabit_env: Dict[str, Union[str, int]]) -> None:
with xgb.collective.CommunicatorContext(**rabit_env):
test_fn()
try:
with xgb.collective.CommunicatorContext(**rabit_env):
test_fn(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
exception_queue.put(e)
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
tracker.start(world_size)
workers = []
for _ in range(world_size):
@@ -957,5 +965,20 @@ def run_with_rabit(world_size: int, test_fn: Callable) -> None:
worker.start()
for worker in workers:
worker.join()
assert exception_queue.empty(), f"Worker failed: {exception_queue.get()}"
tracker.join()
def column_split_feature_names(
feature_names: List[Union[str, int]], world_size: int
) -> List[str]:
"""Get the global list of feature names from the local feature names."""
return [
f"{rank}.{feature}" for rank in range(world_size) for feature in feature_names
]
def is_windows() -> bool:
"""Check if the current platform is Windows."""
return platform.system() == "Windows"