Demo of federated learning using NVFlare (#7879)
Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Distributed XGBoost Rabit related API."""
|
||||
import ctypes
|
||||
from enum import IntEnum, unique
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Any, TypeVar, Callable, Optional, cast, List, Union
|
||||
|
||||
@@ -8,6 +9,8 @@ import numpy as np
|
||||
|
||||
from .core import _LIB, c_str, _check_call
|
||||
|
||||
LOGGER = logging.getLogger("[xgboost.rabit]")
|
||||
|
||||
|
||||
def _init_rabit() -> None:
|
||||
"""internal library initializer."""
|
||||
@@ -224,5 +227,21 @@ def version_number() -> int:
|
||||
return ret
|
||||
|
||||
|
||||
class RabitContext:
|
||||
"""A context controlling rabit initialization and finalization."""
|
||||
|
||||
def __init__(self, args: List[bytes]) -> None:
|
||||
self.args = args
|
||||
|
||||
def __enter__(self) -> None:
|
||||
init(self.args)
|
||||
assert is_distributed()
|
||||
LOGGER.debug("-------------- rabit say hello ------------------")
|
||||
|
||||
def __exit__(self, *args: List) -> None:
|
||||
finalize()
|
||||
LOGGER.debug("--------------- rabit say bye ------------------")
|
||||
|
||||
|
||||
# initialization script
|
||||
_init_rabit()
|
||||
|
||||
Reference in New Issue
Block a user