committed by
GitHub
parent
8aaabce7c9
commit
751160b69c
@@ -12,25 +12,15 @@ rng = np.random.RandomState(1994)
|
||||
class TestGPUBasicModels(unittest.TestCase):
|
||||
cputest = test_bm.TestModels()
|
||||
|
||||
def test_eta_decay_gpu_hist(self):
|
||||
self.cputest.run_eta_decay('gpu_hist')
|
||||
|
||||
def test_deterministic_gpu_hist(self):
|
||||
kRows = 1000
|
||||
kCols = 64
|
||||
kClasses = 4
|
||||
# Create large values to force rounding.
|
||||
X = np.random.randn(kRows, kCols) * 1e4
|
||||
y = np.random.randint(0, kClasses, size=kRows)
|
||||
|
||||
def run_cls(self, X, y, deterministic):
|
||||
cls = xgb.XGBClassifier(tree_method='gpu_hist',
|
||||
deterministic_histogram=True,
|
||||
deterministic_histogram=deterministic,
|
||||
single_precision_histogram=True)
|
||||
cls.fit(X, y)
|
||||
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')
|
||||
|
||||
cls = xgb.XGBClassifier(tree_method='gpu_hist',
|
||||
deterministic_histogram=True,
|
||||
deterministic_histogram=deterministic,
|
||||
single_precision_histogram=True)
|
||||
cls.fit(X, y)
|
||||
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
|
||||
@@ -40,7 +30,24 @@ class TestGPUBasicModels(unittest.TestCase):
|
||||
with open('test_deterministic_gpu_hist-1.json', 'r') as fd:
|
||||
model_1 = fd.read()
|
||||
|
||||
assert hash(model_0) == hash(model_1)
|
||||
|
||||
os.remove('test_deterministic_gpu_hist-0.json')
|
||||
os.remove('test_deterministic_gpu_hist-1.json')
|
||||
|
||||
return hash(model_0), hash(model_1)
|
||||
|
||||
def test_eta_decay_gpu_hist(self):
|
||||
self.cputest.run_eta_decay('gpu_hist')
|
||||
|
||||
def test_deterministic_gpu_hist(self):
|
||||
kRows = 1000
|
||||
kCols = 64
|
||||
kClasses = 4
|
||||
# Create large values to force rounding.
|
||||
X = np.random.randn(kRows, kCols) * 1e4
|
||||
y = np.random.randint(0, kClasses, size=kRows) * 1e4
|
||||
|
||||
model_0, model_1 = self.run_cls(X, y, True)
|
||||
assert model_0 == model_1
|
||||
|
||||
model_0, model_1 = self.run_cls(X, y, False)
|
||||
assert model_0 != model_1
|
||||
|
||||
Reference in New Issue
Block a user