Unify the hist tree method for different devices. (#9363)

This commit is contained in:
Jiaming Yuan
2023-07-11 10:04:39 +08:00
committed by GitHub
parent 20c52f07d2
commit 97ed944209
8 changed files with 242 additions and 142 deletions

View File

@@ -57,12 +57,12 @@ class TestTrainingContinuation:
gbdt_02 = xgb.train(xgb_params_01, dtrain_2class,
num_boost_round=0)
gbdt_02.save_model('xgb_tc.model')
gbdt_02.save_model('xgb_tc.json')
gbdt_02a = xgb.train(xgb_params_01, dtrain_2class,
num_boost_round=10, xgb_model=gbdt_02)
gbdt_02b = xgb.train(xgb_params_01, dtrain_2class,
num_boost_round=10, xgb_model="xgb_tc.model")
num_boost_round=10, xgb_model="xgb_tc.json")
ntrees_02a = len(gbdt_02a.get_dump())
ntrees_02b = len(gbdt_02b.get_dump())
assert ntrees_02a == 10
@@ -78,18 +78,18 @@ class TestTrainingContinuation:
gbdt_03 = xgb.train(xgb_params_01, dtrain_2class,
num_boost_round=3)
gbdt_03.save_model('xgb_tc.model')
gbdt_03.save_model('xgb_tc.json')
gbdt_03a = xgb.train(xgb_params_01, dtrain_2class,
num_boost_round=7, xgb_model=gbdt_03)
gbdt_03b = xgb.train(xgb_params_01, dtrain_2class,
num_boost_round=7, xgb_model="xgb_tc.model")
num_boost_round=7, xgb_model="xgb_tc.json")
ntrees_03a = len(gbdt_03a.get_dump())
ntrees_03b = len(gbdt_03b.get_dump())
assert ntrees_03a == 10
assert ntrees_03b == 10
os.remove('xgb_tc.model')
os.remove('xgb_tc.json')
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))