Unify the hist tree method for different devices. (#9363)
This commit is contained in:
@@ -57,12 +57,12 @@ class TestTrainingContinuation:
|
||||
|
||||
gbdt_02 = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=0)
|
||||
gbdt_02.save_model('xgb_tc.model')
|
||||
gbdt_02.save_model('xgb_tc.json')
|
||||
|
||||
gbdt_02a = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=10, xgb_model=gbdt_02)
|
||||
gbdt_02b = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=10, xgb_model="xgb_tc.model")
|
||||
num_boost_round=10, xgb_model="xgb_tc.json")
|
||||
ntrees_02a = len(gbdt_02a.get_dump())
|
||||
ntrees_02b = len(gbdt_02b.get_dump())
|
||||
assert ntrees_02a == 10
|
||||
@@ -78,18 +78,18 @@ class TestTrainingContinuation:
|
||||
|
||||
gbdt_03 = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=3)
|
||||
gbdt_03.save_model('xgb_tc.model')
|
||||
gbdt_03.save_model('xgb_tc.json')
|
||||
|
||||
gbdt_03a = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=7, xgb_model=gbdt_03)
|
||||
gbdt_03b = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=7, xgb_model="xgb_tc.model")
|
||||
num_boost_round=7, xgb_model="xgb_tc.json")
|
||||
ntrees_03a = len(gbdt_03a.get_dump())
|
||||
ntrees_03b = len(gbdt_03b.get_dump())
|
||||
assert ntrees_03a == 10
|
||||
assert ntrees_03b == 10
|
||||
|
||||
os.remove('xgb_tc.model')
|
||||
os.remove('xgb_tc.json')
|
||||
|
||||
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||
|
||||
Reference in New Issue
Block a user