Calculate base_score based on input labels for mae. (#8107)

Fit an intercept as base score for abs loss.
This commit is contained in:
Jiaming Yuan
2022-09-20 20:53:54 +08:00
committed by GitHub
parent 4f42aa5f12
commit fffb1fca52
42 changed files with 999 additions and 343 deletions

View File

@@ -102,34 +102,38 @@ def run_scikit_model_check(name, path):
@pytest.mark.skipif(**tm.no_sklearn())
def test_model_compatibility():
'''Test model compatibility, can only be run on CI as others don't
"""Test model compatibility, can only be run on CI as others don't
have the credentials.
'''
"""
path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path, 'models')
path = os.path.join(path, "models")
zip_path, _ = urllib.request.urlretrieve('https://xgboost-ci-jenkins-artifacts.s3-us-west-2' +
'.amazonaws.com/xgboost_model_compatibility_test.zip')
with zipfile.ZipFile(zip_path, 'r') as z:
z.extractall(path)
if not os.path.exists(path):
zip_path, _ = urllib.request.urlretrieve(
"https://xgboost-ci-jenkins-artifacts.s3-us-west-2"
+ ".amazonaws.com/xgboost_model_compatibility_test.zip"
)
with zipfile.ZipFile(zip_path, "r") as z:
z.extractall(path)
models = [
os.path.join(root, f) for root, subdir, files in os.walk(path)
os.path.join(root, f)
for root, subdir, files in os.walk(path)
for f in files
if f != 'version'
if f != "version"
]
assert models
for path in models:
name = os.path.basename(path)
if name.startswith('xgboost-'):
if name.startswith("xgboost-"):
booster = xgboost.Booster(model_file=path)
run_booster_check(booster, name)
# Do full serialization.
booster = copy.copy(booster)
run_booster_check(booster, name)
elif name.startswith('xgboost_scikit'):
elif name.startswith("xgboost_scikit"):
run_scikit_model_check(name, path)
else:
assert False