[Python] Require black and isort for new Python files. (#8096)

* [Python] Require black and isort for new Python files.

- Require black and isort for the spark and dask modules.

These files are relatively new and already conform more closely to the black formatter. We will
convert the rest of the library as we move forward.

Other libraries, including dask/distributed and optuna, use the same formatting style and
enforce a stricter standard. The black formatter is quite nice, and automating it helps us
unify the code style.

- Gather Python checks into a single script.
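For contributors, the consolidated script added in this change can be invoked locally the same way the CI workflow below does. Under the hood, the format check amounts to running both tools in check-only mode; a minimal Python sketch of that check (the path is one of the modules covered by this commit):

import subprocess

# isort with the black-compatible profile, in check-only mode
subprocess.run(["isort", "--check", "--profile=black", "python-package/xgboost/dask.py"])
# black, also in check-only mode
subprocess.run(["black", "--check", "python-package/xgboost/dask.py"])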
Jiaming Yuan 2022-07-20 10:25:24 +08:00 committed by GitHub
parent f23cc92130
commit 8bdea72688
13 changed files with 226 additions and 126 deletions

@ -146,24 +146,6 @@ jobs:
run: |
make lint
mypy:
runs-on: ubuntu-latest
name: Type checking for Python
steps:
- uses: actions/checkout@v2
with:
submodules: 'true'
- uses: actions/setup-python@v2
with:
python-version: "3.8"
architecture: 'x64'
- name: Install Python packages
run: |
python -m pip install wheel setuptools mypy pandas dask[complete] distributed
- name: Run mypy
run: |
make mypy
doxygen:
runs-on: ubuntu-latest
name: Generate C/C++ API doc using Doxygen

@ -3,6 +3,37 @@ name: XGBoost-Python-Tests
on: [push, pull_request]
jobs:
python-mypy-lint:
runs-on: ubuntu-latest
name: Type and format checks for the Python package
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.8"]
steps:
- uses: actions/checkout@v2
with:
submodules: 'true'
- uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
python-version: ${{ matrix.python-version }}
activate-environment: python_lint
environment-file: tests/ci_build/conda_env/python_lint.yml
- name: Display Conda env
shell: bash -l {0}
run: |
conda info
conda list
- name: Run mypy
shell: bash -l {0}
run: |
python tests/ci_build/lint_python.py --format=0 --type-check=1
- name: Run formatter
shell: bash -l {0}
run: |
python tests/ci_build/lint_python.py --format=1 --type-check=0
python-sdist-test:
runs-on: ${{ matrix.os }}
name: Test installing XGBoost Python source package on ${{ matrix.os }}

@ -87,19 +87,6 @@ cover: check
endif
# dask is required to pass, others are not
# If any of the dask tests failed, contributor won't see the other error.
mypy:
cd python-package; \
mypy . && \
mypy ../demo/guide-python/external_memory.py && \
mypy ../demo/guide-python/categorical.py && \
mypy ../demo/guide-python/cat_in_the_dat.py && \
mypy ../tests/python-gpu/test_gpu_with_dask.py && \
mypy ../tests/python/test_data_iterator.py && \
mypy ../tests/python-gpu/test_gpu_data_iterator.py || exit 1; \
mypy . || true ;
clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o #xgboost
$(RM) -rf build_tests *.gcov tests/cpp/xgboost_test

@ -31,49 +31,76 @@ Optional dask configuration
dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"})
"""
import platform
import logging
import collections
import logging
import platform
import socket
from contextlib import contextmanager
from collections import defaultdict
from threading import Thread
from contextlib import contextmanager
from functools import partial, update_wrapper
from typing import TYPE_CHECKING, List, Tuple, Callable, Optional, Any, Union, Dict, Set
from typing import Sequence
from typing import Awaitable, Generator, TypeVar
from threading import Thread
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
List,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
import numpy
from . import rabit, config
from .callback import TrainingCallback
from .compat import LazyLoader
from .compat import scipy_sparse
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import lazy_isinstance
from . import config, rabit
from ._typing import FeatureNames, FeatureTypes
from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
from .core import Objective, Metric
from .core import _deprecate_positional_args, _has_categorical
from .training import train as worker_train
from .callback import TrainingCallback
from .compat import (
PANDAS_INSTALLED,
DataFrame,
LazyLoader,
Series,
lazy_isinstance,
pandas_concat,
scipy_sparse,
)
from .core import (
Booster,
DataIter,
DeviceQuantileDMatrix,
DMatrix,
Metric,
Objective,
_deprecate_positional_args,
_expect,
_has_categorical,
)
from .sklearn import (
XGBClassifier,
XGBClassifierBase,
XGBModel,
XGBRanker,
XGBRankerMixIn,
XGBRegressorBase,
_check_rf_callback,
_cls_predict_proba,
_objective_decorator,
_wrap_evaluation_matrices,
xgboost_model_doc,
)
from .tracker import RabitTracker, get_host_ip
from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
from .sklearn import _wrap_evaluation_matrices, _objective_decorator, _check_rf_callback
from .sklearn import XGBRankerMixIn
from .sklearn import xgboost_model_doc
from .sklearn import _cls_predict_proba
from .sklearn import XGBRanker
from .training import train as worker_train
if TYPE_CHECKING:
from dask import dataframe as dd
from dask import array as da
import dask
import distributed
from dask import array as da
from dask import dataframe as dd
else:
dd = LazyLoader("dd", globals(), "dask.dataframe")
da = LazyLoader("da", globals(), "dask.array")

@ -4,23 +4,21 @@
# pylint: disable=too-few-public-methods
import numpy as np
import pandas as pd
from scipy.special import expit, softmax # pylint: disable=no-name-in-module
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml import Estimator, Model
from pyspark.ml.functions import array_to_vector, vector_to_array
from pyspark.ml.linalg import VectorUDT
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import (
HasFeaturesCol,
HasLabelCol,
HasWeightCol,
HasPredictionCol,
HasProbabilityCol,
HasRawPredictionCol,
HasValidationIndicatorCol,
HasWeightCol,
)
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import MLReadable, MLWritable
from pyspark.sql.functions import col, pandas_udf, countDistinct, struct
from pyspark.sql.functions import col, countDistinct, pandas_udf, struct
from pyspark.sql.types import (
ArrayType,
DoubleType,
@ -29,35 +27,32 @@ from pyspark.sql.types import (
LongType,
ShortType,
)
import xgboost
from xgboost import XGBClassifier, XGBRegressor
from scipy.special import expit, softmax # pylint: disable=no-name-in-module
from xgboost.core import Booster
from xgboost.training import train as worker_train
from .data import (
_convert_partition_data_to_dmatrix,
)
import xgboost
from xgboost import XGBClassifier, XGBRegressor
from .data import _convert_partition_data_to_dmatrix
from .model import (
SparkXGBReader,
SparkXGBWriter,
SparkXGBModelReader,
SparkXGBModelWriter,
SparkXGBReader,
SparkXGBWriter,
)
from .params import HasArbitraryParamsDict, HasBaseMarginCol
from .utils import (
get_logger, _get_max_num_concurrent_tasks,
_get_default_params_from_func,
get_class_name,
RabitContext,
_get_rabit_args,
_get_args_from_message_list,
_get_default_params_from_func,
_get_gpu_id,
_get_max_num_concurrent_tasks,
_get_rabit_args,
_get_spark_session,
_is_local,
_get_gpu_id,
)
from .params import (
HasArbitraryParamsDict,
HasBaseMarginCol,
get_class_name,
get_logger,
)
# Put pyspark specific params here, they won't be passed to XGBoost.
@ -282,10 +277,7 @@ class _SparkXGBParams(
.get("spark.task.resource.gpu.amount")
)
is_local = _is_local(
_get_spark_session()
.sparkContext
)
is_local = _is_local(_get_spark_session().sparkContext)
if is_local:
# checking spark local mode.
@ -301,7 +293,8 @@ class _SparkXGBParams(
# gpu numbers and raising the exception.
get_logger(self.__class__.__name__).warning(
"You enabled use_gpu in spark local mode. Please make sure your local node "
"has at least %d GPUs", self.getOrDefault(self.num_workers)
"has at least %d GPUs",
self.getOrDefault(self.num_workers),
)
else:
# checking spark non-local mode.
@ -316,7 +309,7 @@ class _SparkXGBParams(
get_logger(self.__class__.__name__).warning(
"You configured %s GPU cores for each spark task, but in "
"XGBoost training, every Spark task will only use one GPU core.",
gpu_per_task
gpu_per_task,
)
@ -545,7 +538,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
"training is greater than current max number of concurrent "
"spark task slots, you need wait until more task slots available "
"or you need increase spark cluster workers.",
num_workers
num_workers,
)
if self._repartition_needed(dataset):
@ -568,10 +561,7 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
booster_params["nthread"] = cpu_per_task
use_gpu = self.getOrDefault(self.use_gpu)
is_local = _is_local(
_get_spark_session()
.sparkContext
)
is_local = _is_local(_get_spark_session().sparkContext)
def _train_booster(pandas_df_iter):
"""
@ -584,8 +574,9 @@ class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
context.barrier()
if use_gpu:
booster_params["gpu_id"] = context.partitionId() if is_local \
else _get_gpu_id(context)
booster_params["gpu_id"] = (
context.partitionId() if is_local else _get_gpu_id(context)
)
_rabit_args = ""
if context.partitionId() == 0:

@ -2,8 +2,10 @@
"""Xgboost pyspark integration submodule for data related functions."""
# pylint: disable=too-many-arguments
from typing import Iterator
import numpy as np
import pandas as pd
from xgboost import DMatrix

@ -2,12 +2,14 @@
"""Xgboost pyspark integration submodule for estimator API."""
# pylint: disable=too-many-ancestors
from pyspark.ml.param.shared import HasProbabilityCol, HasRawPredictionCol
from xgboost import XGBClassifier, XGBRegressor
from .core import (
_SparkXGBEstimator,
SparkXGBClassifierModel,
SparkXGBRegressorModel,
_set_pyspark_xgb_cls_param_attrs,
_SparkXGBEstimator,
)

@ -5,13 +5,12 @@ import base64
import os
import uuid
from pyspark import cloudpickle
from pyspark import SparkFiles
from pyspark.sql import SparkSession
from pyspark import SparkFiles, cloudpickle
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
from pyspark.sql import SparkSession
from xgboost.core import Booster
from .utils import get_logger, get_class_name
from .utils import get_class_name, get_logger
def _get_or_create_tmp_dir():
@ -250,7 +249,9 @@ class SparkXGBModelReader(MLReader):
self.cls, path, self.sc, self.logger
)
xgb_sklearn_params = py_model._gen_xgb_params_dict(gen_xgb_sklearn_estimator_param=True)
xgb_sklearn_params = py_model._gen_xgb_params_dict(
gen_xgb_sklearn_estimator_param=True
)
model_load_path = os.path.join(path, "model.json")
ser_xgb_model = (
@ -263,8 +264,6 @@ class SparkXGBModelReader(MLReader):
def create_xgb_model():
return self.cls._xgb_cls()(**xgb_sklearn_params)
xgb_model = deserialize_xgb_model(
ser_xgb_model, create_xgb_model
)
xgb_model = deserialize_xgb_model(ser_xgb_model, create_xgb_model)
py_model._xgb_sklearn_model = xgb_model
return py_model

@ -1,15 +1,15 @@
# type: ignore
"""Xgboost pyspark integration submodule for helper functions."""
import inspect
from threading import Thread
import sys
import logging
import sys
from threading import Thread
import pyspark
from pyspark.sql.session import SparkSession
from xgboost.tracker import RabitTracker
from xgboost import rabit
from xgboost.tracker import RabitTracker
def get_class_name(cls):
@ -140,13 +140,11 @@ def _get_gpu_id(task_context) -> int:
"""Get the gpu id from the task resources"""
if task_context is None:
# This is a safety check.
raise RuntimeError(
"_get_gpu_id should not be invoked from driver side."
)
raise RuntimeError("_get_gpu_id should not be invoked from driver side.")
resources = task_context.resources()
if 'gpu' not in resources:
if "gpu" not in resources:
raise RuntimeError(
"Couldn't get the gpu id, Please check the GPU resource configuration"
)
# return the first gpu id.
return int(resources['gpu'].addresses[0].strip())
return int(resources["gpu"].addresses[0].strip())

@ -0,0 +1,15 @@
name: python_lint
channels:
- conda-forge
dependencies:
- python=3.8
- wheel
- setuptools
- mypy
- numpy
- scipy
- pandas
- dask
- distributed
- black
- isort

@ -0,0 +1,63 @@
import argparse
import os
import subprocess
import sys
from test_utils import DirectoryExcursion
CURDIR = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
PROJECT_ROOT = os.path.normpath(os.path.join(CURDIR, os.path.pardir, os.path.pardir))
def run_formatter(rel_path: str):
path = os.path.join(PROJECT_ROOT, rel_path)
isort_ret = subprocess.run(["isort", "--check", "--profile=black", path]).returncode
black_ret = subprocess.run(["black", "--check", path]).returncode
if isort_ret != 0 or black_ret != 0:
msg = (
"Please run the following command on your machine to address the format"
f" errors:\n isort --check --profile=black {rel_path}\n black {rel_path}\n"
)
print(msg, file=sys.stdout)
return False
return True
def run_mypy(rel_path: str):
with DirectoryExcursion(os.path.join(PROJECT_ROOT, "python-package")):
path = os.path.join(PROJECT_ROOT, rel_path)
ret = subprocess.run(["mypy", path])
if ret.returncode != 0:
return False
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--format", type=int, choices=[0, 1], default=1)
parser.add_argument("--type-check", type=int, choices=[0, 1], default=1)
args = parser.parse_args()
if args.format == 1:
if not all(
[
run_formatter("python-package/xgboost/dask.py"),
run_formatter("python-package/xgboost/spark"),
]
):
sys.exit(-1)
if args.type_check == 1:
if not all(
run_mypy(path)
for path in [
"python-package/xgboost/",
"demo/guide-python/external_memory.py",
"demo/guide-python/cat_in_the_dat.py",
"tests/python/test_data_iterator.py",
"tests/python-gpu/test_gpu_with_dask.py",
"tests/python-gpu/test_gpu_data_iterator.py",
]
):
sys.exit(-1)
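Both flags default to 1, so running the script with no arguments performs the format and type checks together; the CI workflow above splits them into separate steps. The helpers can also be reused on their own, for example (an illustrative sketch, assuming tests/ci_build is the current working directory so that lint_python is importable):

from lint_python import run_formatter

# check a single module and propagate a failure as a non-zero exit code
if not run_formatter("python-package/xgboost/dask.py"):
    raise SystemExit(1)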

@ -1,6 +1,7 @@
import argparse
import os
import subprocess
from test_utils import DirectoryExcursion
ROOT = os.path.normpath(
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.pardir,
@ -8,18 +9,6 @@ ROOT = os.path.normpath(
r_package = os.path.join(ROOT, 'R-package')
class DirectoryExcursion:
def __init__(self, path: os.PathLike):
self.path = path
self.curdir = os.path.normpath(os.path.abspath(os.path.curdir))
def __enter__(self):
os.chdir(self.path)
def __exit__(self, *args):
os.chdir(self.curdir)
def get_mingw_bin():
return os.path.join('c:/rtools40/mingw64/', 'bin')

@ -0,0 +1,14 @@
import os
from typing import Union
class DirectoryExcursion:
def __init__(self, path: Union[os.PathLike, str]):
self.path = path
self.curdir = os.path.normpath(os.path.abspath(os.path.curdir))
def __enter__(self):
os.chdir(self.path)
def __exit__(self, *args):
os.chdir(self.curdir)
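
A brief usage sketch of the shared helper (the directory name is illustrative); lint_python.py above uses it the same way to run mypy from inside python-package:

from test_utils import DirectoryExcursion

with DirectoryExcursion("python-package"):
    ...  # code here runs with python-package as the working directory
# the previous working directory is restored on exit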