Calculate base_score based on input labels for mae. (#8107)
Fit an intercept as base score for abs loss.
This commit is contained in:
@@ -102,34 +102,38 @@ def run_scikit_model_check(name, path):
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_model_compatibility():
|
||||
'''Test model compatibility, can only be run on CI as others don't
|
||||
"""Test model compatibility, can only be run on CI as others don't
|
||||
have the credentials.
|
||||
|
||||
'''
|
||||
"""
|
||||
path = os.path.dirname(os.path.abspath(__file__))
|
||||
path = os.path.join(path, 'models')
|
||||
path = os.path.join(path, "models")
|
||||
|
||||
zip_path, _ = urllib.request.urlretrieve('https://xgboost-ci-jenkins-artifacts.s3-us-west-2' +
|
||||
'.amazonaws.com/xgboost_model_compatibility_test.zip')
|
||||
with zipfile.ZipFile(zip_path, 'r') as z:
|
||||
z.extractall(path)
|
||||
if not os.path.exists(path):
|
||||
zip_path, _ = urllib.request.urlretrieve(
|
||||
"https://xgboost-ci-jenkins-artifacts.s3-us-west-2"
|
||||
+ ".amazonaws.com/xgboost_model_compatibility_test.zip"
|
||||
)
|
||||
with zipfile.ZipFile(zip_path, "r") as z:
|
||||
z.extractall(path)
|
||||
|
||||
models = [
|
||||
os.path.join(root, f) for root, subdir, files in os.walk(path)
|
||||
os.path.join(root, f)
|
||||
for root, subdir, files in os.walk(path)
|
||||
for f in files
|
||||
if f != 'version'
|
||||
if f != "version"
|
||||
]
|
||||
assert models
|
||||
|
||||
for path in models:
|
||||
name = os.path.basename(path)
|
||||
if name.startswith('xgboost-'):
|
||||
if name.startswith("xgboost-"):
|
||||
booster = xgboost.Booster(model_file=path)
|
||||
run_booster_check(booster, name)
|
||||
# Do full serialization.
|
||||
booster = copy.copy(booster)
|
||||
run_booster_check(booster, name)
|
||||
elif name.startswith('xgboost_scikit'):
|
||||
elif name.startswith("xgboost_scikit"):
|
||||
run_scikit_model_check(name, path)
|
||||
else:
|
||||
assert False
|
||||
|
||||
Reference in New Issue
Block a user